import numpy as np
import random
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.colors import ListedColormap
import matplotlib.gridspec as gridspec
import scipy
from scipy.stats import multivariate_normal
from scipy.sparse import linalg
from collections import defaultdict
from tqdm import tqdm
import networkx as nx
$\cdot$ specification about chunk order:
The order of the chunks may not be contiguous from start to end. However, the order within each section is contiguous, and out-of-order chunk numbers correspond to print-outs produced by running earlier chunks.
$\cdot$ About the use of sklearn in 2.1:
There is an import of sklearn, which is used in 2.1 for k-means clustering. This is only for the purpose of checking results. All the code in this notebook submitted for assessment is written with elementary Python functions or allowed packages.
Enjoy!
In this section, an MLP is built from scratch. For given activation functions, SGD is used as the optimization method and KL divergence as the loss function. The learning rate is then adjusted to find the optimal one, and the layer width is varied to study its effect on model performance. As the MLP wraps a very large number of parameters, dropout is used to regularize the network and to construct a 'sub-network'. The effect of such regularization on model performance is discussed through training and test losses and accuracies.
A probabilistic substitute for an MLP is a deep Gaussian process. In the last part of the session, the histogram of the output of the first layer is plotted for both the dropout and non-dropout cases, and the effect of dropout is discussed from this perspective.
# load data and do some data processing
# load data and do some data processing
# Each CSV row is one image: column 0 is the digit label, the remaining
# 784 columns are the 28x28 pixel intensities (0-255).
MNIST_train = pd.read_csv("MNIST_train.csv")
MNIST_test = pd.read_csv("MNIST_test.csv")
display(MNIST_train.head())
print("Shape of MNIST_train: ", MNIST_train.shape)
print("Shape of MNIST_test: ", MNIST_test.shape)
# convert to numpy array
MNIST_train = MNIST_train.to_numpy()
MNIST_test = MNIST_test.to_numpy()
# target-predictor split
# Pixels are scaled to [0, 1] by dividing by 255; labels stay as ints 0-9.
x_train, y_train = MNIST_train[:,1:]/255, MNIST_train[:,0]
x_test, y_test = MNIST_test[:,1:]/255, MNIST_test[:,0]
print("Shape of x_train: ", x_train.shape)
print("Shape of y_train: ", y_train.shape)
| label | 1x1 | 1x2 | 1x3 | 1x4 | 1x5 | 1x6 | 1x7 | 1x8 | 1x9 | ... | 28x19 | 28x20 | 28x21 | 28x22 | 28x23 | 28x24 | 28x25 | 28x26 | 28x27 | 28x28 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 7 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 785 columns
Shape of MNIST_train: (6000, 785) Shape of MNIST_test: (1000, 785) Shape of x_train: (6000, 784) Shape of y_train: (6000,)
# peek at the raw (unscaled) pixel values of the first test image
MNIST_test[:, 1:][0]
array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 242, 205,
19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 57, 242,
253, 253, 166, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13,
241, 254, 253, 253, 173, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 56,
95, 216, 253, 254, 253, 253, 241, 87, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
15, 195, 253, 253, 253, 254, 253, 253, 253, 239, 88, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 83, 253, 253, 253, 253, 254, 253, 253, 253, 253, 239,
88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 68, 233, 253, 253, 253, 232, 214, 213, 213, 226,
253, 253, 241, 86, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 23, 206, 253, 253, 250, 148, 44, 0, 0,
0, 30, 128, 246, 253, 243, 68, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 214, 253, 253, 253, 219, 0, 0,
0, 0, 0, 0, 0, 110, 253, 253, 184, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 50, 235, 253, 253, 232, 44,
0, 0, 0, 0, 0, 0, 0, 26, 211, 253, 240, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 254, 254, 254,
80, 0, 0, 0, 0, 0, 0, 0, 0, 26, 213, 255, 241,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 103, 252, 253,
225, 47, 4, 0, 0, 0, 0, 0, 0, 0, 0, 54, 253,
253, 240, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 228,
253, 253, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 181,
233, 253, 253, 120, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 241, 253, 228, 36, 0, 0, 0, 0, 0, 0, 0, 0,
69, 245, 253, 253, 188, 17, 0, 0, 0, 0, 0, 0, 0,
0, 0, 96, 251, 253, 67, 0, 0, 0, 0, 0, 0, 0,
0, 111, 217, 253, 253, 243, 69, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 121, 253, 253, 68, 0, 0, 0, 0, 0,
0, 54, 181, 247, 253, 253, 242, 98, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 102, 251, 253, 228, 58, 9, 41,
55, 97, 112, 255, 253, 253, 253, 243, 98, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 227, 253, 253, 229,
180, 253, 253, 253, 253, 255, 253, 243, 184, 69, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 155,
251, 253, 253, 253, 253, 253, 253, 241, 240, 81, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 102, 225, 253, 253, 253, 246, 120, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0], dtype=int64)
# input layer
def flatten(x):
    """Flatten every element of *x* to a 1-D vector; stack as a 2-D array."""
    return np.array(list(map(np.ndarray.flatten, x)))
# softplus function
# softplus function
def softplus(x, beta=1):
    """Numerically stable softplus: log(1 + exp(beta*x)) / beta.

    The naive form overflows in np.exp for large beta*x; the identity
    softplus(z) = max(z, 0) + log1p(exp(-|z|)) is exact and never
    exponentiates a positive number.
    """
    z = beta * x
    return (np.maximum(z, 0) + np.log1p(np.exp(-np.abs(z)))) / beta
# derivative of softplus
# derivative of softplus
def softplus_deriv(x, beta=1):
    """Derivative of softplus w.r.t. x: the logistic sigmoid of beta*x.

    Computed via exp(-|z|) so neither branch overflows; the naive
    exp(z)/(1+exp(z)) yields inf/inf = nan for large z.
    """
    z = beta * x
    e = np.exp(-np.abs(z))
    # z >= 0: 1/(1+exp(-z)) ; z < 0: exp(z)/(1+exp(z)) — both use exp(-|z|)
    return np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
# softmax function
# softmax function
def softmax(x):
    """Row-wise softmax of a 2-D array of logits.

    Subtracts each row's max before exponentiating — a shift that leaves
    the result mathematically unchanged but prevents np.exp overflow for
    large logits (the original overflowed to inf and returned nan).
    """
    shifted = x - np.max(x, axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=1, keepdims=True)
# KL divergence loss function
# KL divergence loss function
def kl_loss(p, q):
    """Per-row KL divergence KL(p || q) for 2-D probability arrays.

    Returns a 1-D array with one value per row (the mean over rows is
    taken by the caller).  Zeros are replaced by 1e-10 to avoid log(0)
    underflow; unlike the original, the replacement is done on local
    copies so the caller's arrays are never mutated in place.
    """
    p = np.where(p == 0, 1e-10, p)
    q = np.where(q == 0, 1e-10, q)
    return np.sum(p * (np.log(p + 1e-8) - np.log(q + 1e-8)), axis=1)
# accuracy
def mlp_accuracy(y_pred, y_test):
    """Fraction of predicted labels that match the true labels."""
    matches = y_pred == y_test
    return np.mean(matches)
# prob vector to one single prediction
def mlp_prediction(y_pred):
    """Collapse each row of class probabilities to the index of its maximum."""
    return y_pred.argmax(axis=1)
# compute the output error
def output_error(y_batch, a):
    """Error at the output layer: softmax(a) minus the one-hot targets."""
    probs = softmax(a)
    error = probs - y_batch
    return error
# used to transform labels to probability distribution
def one_hot(Y):
    """Map integer labels 0-9 to one-hot rows of the 10x10 identity."""
    identity = np.identity(10)
    return identity[Y]
# considering overfitting problems, use regularization technique
def dropout_mask(n, dropout_prob):
    """Inverted-dropout mask of shape *n*.

    Each entry is kept (1) with probability 1 - dropout_prob, then the
    whole mask is scaled by 1/(1 - dropout_prob) so activation
    expectations are unchanged at evaluation time.
    """
    keep_prob = 1 - dropout_prob
    mask = np.random.binomial(1, keep_prob, size=n)
    return mask / keep_prob
initialize parameters
# initialize the parameters: 3 hidden layers, each with 200 neurons; output layer with 10 neurons, one for each class
# use Glorot initialisation to initialize weights and bias
# initialize the parameters: 3 hidden layers, each with 200 neurons; output layer with 10 neurons, one for each class
# use Glorot initialisation to initialize weights and bias
def init_params(width=200):
    """Glorot-initialised weights and zero biases for a 784-w-w-w-10 MLP.

    Returns a pair of dicts: {'W1'..'W4'} and {'b1'..'b4'}.  Weights are
    drawn N(0, 2/(fan_in+fan_out)); biases start at zero.
    """
    def glorot(fan_in, fan_out):
        # Glorot/Xavier scale sqrt(2 / (fan_in + fan_out))
        scale = np.sqrt(2. / (fan_in + fan_out))
        return np.random.randn(fan_in, fan_out) * scale

    # dict literals evaluate in order, preserving the RNG draw sequence
    weights = {
        'W1': glorot(784, width),
        'W2': glorot(width, width),
        'W3': glorot(width, width),
        'W4': glorot(width, 10),
    }
    biases = {
        'b1': np.zeros(width),
        'b2': np.zeros(width),
        'b3': np.zeros(width),
        'b4': np.zeros(10),
    }
    return weights, biases
forward propagation
def forward_prop(x, weights, bias, dropout_prob=0):
    """One forward pass through the 3-hidden-layer softplus MLP.

    Returns three dicts: pre-activations Z1-Z4, activations A1-A4 (A4 is
    the softmax output), and the dropout masks d1-d3, which backward_prop
    reuses so the same units are dropped in the backward pass.
    With dropout_prob=0 (the evaluation default) every mask is all ones.
    """
    W1, W2, W3, W4 = weights.values()
    b1,b2,b3,b4 = bias.values()
    # input
    Z1 = np.dot(x,W1) + b1
    # inverted dropout: mask already scaled by 1/(1-p), see dropout_mask
    d1 = dropout_mask(Z1.shape,dropout_prob)
    A1 = softplus(Z1)
    A1 *= d1
    # hidden
    Z2 = np.dot(A1,W2) + b2
    d2 = dropout_mask(Z2.shape,dropout_prob)
    A2 = softplus(Z2)
    A2 *= d2
    Z3 = np.dot(A2,W3) + b3
    d3 = dropout_mask(Z3.shape,dropout_prob)
    A3 = softplus(Z3)
    A3 *= d3
    # output (no dropout on the output layer)
    Z4 = np.dot(A3,W4) + b4
    A4 = softmax(Z4)
    return {'Z1': Z1,'Z2': Z2,'Z3': Z3,'Z4': Z4 },{'A1': A1, 'A2': A2,'A3': A3,'A4': A4}, {'d1':d1,'d2':d2,'d3':d3}
backward propagation
def backward_prop(x_batch, y_batch, outputs, weights, bias):
    """Backpropagate through the 4-layer MLP for one mini-batch.

    outputs is the (Z, A, d) triple returned by forward_prop.  Returns a
    dict of gradients; note the dW* matrices are stored TRANSPOSED
    relative to W* (update_params applies .T before subtracting).

    Fix: bias gradients now sum over the batch axis only
    (np.sum(dZ, axis=0)), giving one gradient per bias unit.  The
    original summed over the whole matrix, collapsing each layer's bias
    gradient to a single scalar shared by every neuron.
    """
    m = y_batch.shape[0]
    Z1, Z2, Z3, Z4 = outputs[0].values()
    A1, A2, A3, A4 = outputs[1].values()
    d1, d2, d3 = outputs[2].values()
    W1, W2, W3, W4 = weights.values()
    # output layer: softmax error against one-hot targets
    dZ4 = A4 - one_hot(y_batch)
    dW4 = 1 / m * np.dot(dZ4.T, A3)
    db4 = 1 / m * np.sum(dZ4, axis=0)
    # hidden layers: chain rule through softplus, re-applying the same
    # dropout masks used in the forward pass
    dZ3 = np.dot(dZ4, W4.T) * softplus_deriv(Z3)
    dZ3 *= d3
    dW3 = 1 / m * np.dot(dZ3.T, A2)
    db3 = 1 / m * np.sum(dZ3, axis=0)
    dZ2 = np.dot(dZ3, W3.T) * softplus_deriv(Z2)
    dZ2 *= d2
    dW2 = 1 / m * np.dot(dZ2.T, A1)
    db2 = 1 / m * np.sum(dZ2, axis=0)
    dZ1 = np.dot(dZ2, W2.T) * softplus_deriv(Z1)
    dZ1 *= d1
    dW1 = 1 / m * np.dot(dZ1.T, x_batch)
    db1 = 1 / m * np.sum(dZ1, axis=0)
    return {'dW4': dW4, 'db4': db4, 'dW3': dW3, 'db3': db3, 'dW2': dW2, 'db2': db2, 'dW1': dW1, 'db1': db1}
update the parameters (SGD)
def update_params(weights, bias, grads, lr):
    """One SGD step, updating weights and biases in place.

    The dW* gradients are stored transposed relative to W* (see
    backward_prop), hence the .T before subtracting.
    """
    for layer in (1, 2, 3, 4):
        weights[f'W{layer}'] -= lr * grads[f'dW{layer}'].T
        bias[f'b{layer}'] -= lr * grads[f'db{layer}']
    return weights, bias
MLP training
def train(x_train, y_train, x_test, y_test, lr, epochs, batch_size, width=200, dropout_prob=0):
    """Train the MLP with mini-batch SGD.

    Returns (weights, bias, training_losses, training_accuracies,
    test_losses, test_accuracies); each metric list holds one entry per
    epoch, computed after the epoch's final batch.

    Fix: the per-epoch KL loss is now computed between the one-hot labels
    and the predicted class probabilities A4.  The original compared
    one_hot(labels) with one_hot(argmax), which reduces the "KL loss" to
    a scaled error rate and ignores the network's confidence entirely.
    """
    # initialization
    weights, bias = init_params(width=width)
    training_losses = []
    training_accuracies = []
    test_losses = []
    test_accuracies = []
    # sgd implementation
    for epoch in range(1, epochs+1):
        # reshuffle the training set each epoch
        indices = np.random.permutation(len(x_train))
        x_train = x_train[indices]
        y_train = y_train[indices]
        n_obs = x_train.shape[0]
        # a progress bar to track the epoch
        progress_bar = tqdm(range(0, n_obs, batch_size), desc = f"Epoch {epoch}", unit = "batch", colour='green')
        count = 0
        # iterate through batches
        for start in progress_bar:
            stop = start + batch_size
            x_batch, y_batch = x_train[start:stop], y_train[start:stop]
            forward_outputs = forward_prop(x_batch, weights, bias, dropout_prob)
            grads = backward_prop(x_batch, y_batch, forward_outputs, weights, bias)
            weights, bias = update_params(weights, bias, grads, lr)
            count += 1
            # evaluate once per epoch, after the last batch
            if count == len(progress_bar):
                # evaluation passes use dropout_prob=0 (the default), so
                # no units are dropped when measuring performance
                forward_outputs_train = forward_prop(x_train, weights, bias)
                probs_train = forward_outputs_train[1]['A4']
                predictions_train = mlp_prediction(probs_train)
                training_accuracy = mlp_accuracy(predictions_train, y_train)
                training_loss = np.mean(kl_loss(one_hot(y_train), probs_train))
                # forward propagation on test set
                forward_outputs_test = forward_prop(x_test, weights, bias)
                probs_test = forward_outputs_test[1]['A4']
                predictions_test = mlp_prediction(probs_test)
                test_accuracy = mlp_accuracy(predictions_test, y_test)
                test_loss = np.mean(kl_loss(one_hot(y_test), probs_test))
                # record the epoch's metrics
                training_losses.append(training_loss)
                training_accuracies.append(training_accuracy)
                test_losses.append(test_loss)
                test_accuracies.append(test_accuracy)
                progress_bar.set_postfix(train_loss = training_loss, train_accuracy = training_accuracy, test_loss = test_loss, test_accuracy = test_accuracy)
    return weights, bias, training_losses, training_accuracies,test_losses,test_accuracies
# baseline hyperparameters for the first training run
lr = 0.18
epochs = 40
batch_size = 128
# no dropout here (dropout_prob defaults to 0) and default width 200
w,b, train_loss, train_acc, test_loss,test_acc = train(x_train, y_train, x_test, y_test, lr, epochs, batch_size)
Epoch 1: 100%|█| 47/47 [00:01<00:00, 36.78batch/s, test_accuracy=0.242, test_loss=14, train_accuracy=0.228, train_loss= Epoch 2: 100%|█| 47/47 [00:01<00:00, 38.31batch/s, test_accuracy=0.588, test_loss=7.59, train_accuracy=0.596, train_los Epoch 3: 100%|█| 47/47 [00:01<00:00, 40.14batch/s, test_accuracy=0.677, test_loss=5.95, train_accuracy=0.679, train_los Epoch 4: 100%|█| 47/47 [00:01<00:00, 44.26batch/s, test_accuracy=0.82, test_loss=3.31, train_accuracy=0.815, train_loss Epoch 5: 100%|█| 47/47 [00:01<00:00, 36.73batch/s, test_accuracy=0.831, test_loss=3.11, train_accuracy=0.823, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 40.75batch/s, test_accuracy=0.87, test_loss=2.39, train_accuracy=0.86, train_loss= Epoch 7: 100%|█| 47/47 [00:01<00:00, 39.83batch/s, test_accuracy=0.872, test_loss=2.36, train_accuracy=0.868, train_los Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.74batch/s, test_accuracy=0.878, test_loss=2.25, train_accuracy=0.881, train_los Epoch 9: 100%|█| 47/47 [00:00<00:00, 50.40batch/s, test_accuracy=0.909, test_loss=1.68, train_accuracy=0.901, train_los Epoch 10: 100%|█| 47/47 [00:00<00:00, 49.27batch/s, test_accuracy=0.923, test_loss=1.42, train_accuracy=0.911, train_lo Epoch 11: 100%|█| 47/47 [00:00<00:00, 51.39batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.91, train_los Epoch 12: 100%|█| 47/47 [00:00<00:00, 50.68batch/s, test_accuracy=0.917, test_loss=1.53, train_accuracy=0.926, train_lo Epoch 13: 100%|█| 47/47 [00:01<00:00, 40.70batch/s, test_accuracy=0.905, test_loss=1.75, train_accuracy=0.906, train_lo Epoch 14: 100%|█| 47/47 [00:01<00:00, 38.44batch/s, test_accuracy=0.881, test_loss=2.19, train_accuracy=0.879, train_lo Epoch 15: 100%|█| 47/47 [00:01<00:00, 43.99batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.922, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 45.85batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.939, train_los Epoch 17: 100%|█| 47/47 [00:00<00:00, 50.67batch/s, test_accuracy=0.925, 
test_loss=1.38, train_accuracy=0.925, train_lo Epoch 18: 100%|█| 47/47 [00:00<00:00, 52.62batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.943, train_lo Epoch 19: 100%|█| 47/47 [00:00<00:00, 51.45batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.945, train_lo Epoch 20: 100%|█| 47/47 [00:00<00:00, 50.95batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.945, train_lo Epoch 21: 100%|█| 47/47 [00:00<00:00, 50.54batch/s, test_accuracy=0.905, test_loss=1.75, train_accuracy=0.913, train_lo Epoch 22: 100%|█| 47/47 [00:00<00:00, 50.55batch/s, test_accuracy=0.908, test_loss=1.69, train_accuracy=0.933, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 40.33batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.941, train_lo Epoch 24: 100%|█| 47/47 [00:01<00:00, 44.73batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.956, train_los Epoch 25: 100%|█| 47/47 [00:01<00:00, 44.52batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.955, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.24batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.96, train_los Epoch 27: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.961, train_lo Epoch 28: 100%|█| 47/47 [00:00<00:00, 48.78batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.964, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 43.35batch/s, test_accuracy=0.923, test_loss=1.42, train_accuracy=0.94, train_los Epoch 30: 100%|█| 47/47 [00:00<00:00, 47.49batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.968, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.75batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.962, train_los Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.62batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.915, train_loss Epoch 33: 100%|█| 47/47 [00:01<00:00, 38.86batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.975, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 45.31batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.969, train_los Epoch 35: 100%|█| 47/47 [00:01<00:00, 42.03batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.976, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 46.42batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.97, train_los Epoch 37: 100%|█| 47/47 [00:01<00:00, 37.34batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.974, train_lo Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.50batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.973, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 42.19batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.973, train_lo Epoch 40: 100%|█| 47/47 [00:01<00:00, 45.57batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.988, train_lo
# print train_loss: the per-epoch training losses recorded by train()
train_loss
[14.210015430141066, 7.441003545258493, 5.9037075550834395, 3.4151904931433825, 3.258699284562688, 2.577502258976138, 2.427147960535863, 2.193945375199927, 1.8287992218449738, 1.629349642281344, 1.6631026480536508, 1.365462506243311, 1.7275402045280541, 2.2307668360424424, 1.4329685177879243, 1.119986100626536, 1.3808047815943594, 1.0555485441521326, 1.0217955383798258, 1.0064532630287777, 1.6048020017196665, 1.2243135730136654, 1.0954384600648586, 0.816209048675777, 0.8254144138864059, 0.7425661269907444, 0.7180184864290668, 0.6627862951652925, 1.1107807354159072, 0.5830064633398406, 0.7026762110780186, 1.5679805408771506, 0.46640517067187254, 0.5768695531994212, 0.4357206199697756, 0.5492534575675341, 0.4848159010931306, 0.500158176444179, 0.49095281123355, 0.23013413026572657]
# set different learning rates: 7 values log-spaced between 1e-4 and 1
learning_rates = np.logspace(-4, 0, 7)
epochs = 40
batch_size = 128
train_losses = []
test_losses = []
for learning_rate in learning_rates:
    print(f'Training MLP for learning rate = {learning_rate}')
    w,b, train_loss, train_acc, test_loss,test_acc = train(x_train, y_train, x_test, y_test, learning_rate, epochs,batch_size)
    # record the final loss (last epoch) for this learning rate
    train_losses.append(train_loss[-1])
    test_losses.append(test_loss[-1])
Training MLP for learning rate = 0.0001
Epoch 1: 100%|█| 47/47 [00:01<00:00, 42.42batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 2: 100%|█| 47/47 [00:01<00:00, 44.39batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 3: 100%|█| 47/47 [00:01<00:00, 34.45batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 4: 100%|█| 47/47 [00:03<00:00, 11.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 5: 100%|█| 47/47 [00:02<00:00, 20.31batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 6: 100%|█| 47/47 [00:01<00:00, 25.36batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 7: 100%|█| 47/47 [00:02<00:00, 20.49batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 8: 100%|█| 47/47 [00:02<00:00, 22.41batch/s, test_accuracy=0.098, test_loss=16.6, train_accuracy=0.0985, train_lo Epoch 9: 100%|█| 47/47 [00:01<00:00, 30.30batch/s, test_accuracy=0.095, test_loss=16.7, train_accuracy=0.0968, train_lo Epoch 10: 100%|█| 47/47 [00:04<00:00, 11.03batch/s, test_accuracy=0.081, test_loss=16.9, train_accuracy=0.089, train_lo Epoch 11: 100%|█| 47/47 [00:02<00:00, 19.39batch/s, test_accuracy=0.072, test_loss=17.1, train_accuracy=0.08, train_los Epoch 12: 100%|█| 47/47 [00:01<00:00, 24.93batch/s, test_accuracy=0.063, test_loss=17.3, train_accuracy=0.0705, train_l Epoch 13: 100%|█| 47/47 [00:02<00:00, 19.62batch/s, test_accuracy=0.064, test_loss=17.2, train_accuracy=0.0697, train_l Epoch 14: 100%|█| 47/47 [00:01<00:00, 26.84batch/s, test_accuracy=0.06, test_loss=17.3, train_accuracy=0.0688, train_lo Epoch 15: 100%|█| 47/47 [00:01<00:00, 32.97batch/s, test_accuracy=0.063, test_loss=17.3, train_accuracy=0.0687, train_l Epoch 16: 100%|█| 47/47 [00:01<00:00, 32.97batch/s, test_accuracy=0.064, test_loss=17.2, train_accuracy=0.071, train_lo Epoch 17: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.07, 
test_loss=17.1, train_accuracy=0.0745, train_lo Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.64batch/s, test_accuracy=0.076, test_loss=17, train_accuracy=0.0775, train_los Epoch 19: 100%|█| 47/47 [00:01<00:00, 33.60batch/s, test_accuracy=0.079, test_loss=17, train_accuracy=0.081, train_loss Epoch 20: 100%|█| 47/47 [00:01<00:00, 33.93batch/s, test_accuracy=0.08, test_loss=16.9, train_accuracy=0.082, train_los Epoch 21: 100%|█| 47/47 [00:01<00:00, 31.98batch/s, test_accuracy=0.085, test_loss=16.8, train_accuracy=0.0843, train_l Epoch 22: 100%|█| 47/47 [00:01<00:00, 34.69batch/s, test_accuracy=0.091, test_loss=16.7, train_accuracy=0.0872, train_l Epoch 23: 100%|█| 47/47 [00:01<00:00, 33.70batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0897, train_l Epoch 24: 100%|█| 47/47 [00:01<00:00, 34.21batch/s, test_accuracy=0.097, test_loss=16.6, train_accuracy=0.0922, train_l Epoch 25: 100%|█| 47/47 [00:01<00:00, 32.84batch/s, test_accuracy=0.099, test_loss=16.6, train_accuracy=0.0942, train_l Epoch 26: 100%|█| 47/47 [00:01<00:00, 32.82batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.0962, train_los Epoch 27: 100%|█| 47/47 [00:01<00:00, 33.49batch/s, test_accuracy=0.101, test_loss=16.6, train_accuracy=0.0973, train_l Epoch 28: 100%|█| 47/47 [00:01<00:00, 32.70batch/s, test_accuracy=0.105, test_loss=16.5, train_accuracy=0.0982, train_l Epoch 29: 100%|█| 47/47 [00:01<00:00, 28.34batch/s, test_accuracy=0.106, test_loss=16.5, train_accuracy=0.0998, train_l Epoch 30: 100%|█| 47/47 [00:01<00:00, 27.95batch/s, test_accuracy=0.108, test_loss=16.4, train_accuracy=0.101, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 33.67batch/s, test_accuracy=0.108, test_loss=16.4, train_accuracy=0.102, train_lo Epoch 32: 100%|█| 47/47 [00:01<00:00, 36.49batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.103, train_lo Epoch 33: 100%|█| 47/47 [00:01<00:00, 33.64batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.104, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 30.15batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.106, train_los Epoch 35: 100%|█| 47/47 [00:01<00:00, 32.09batch/s, test_accuracy=0.112, test_loss=16.3, train_accuracy=0.105, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 30.67batch/s, test_accuracy=0.111, test_loss=16.4, train_accuracy=0.105, train_lo Epoch 37: 100%|█| 47/47 [00:01<00:00, 34.75batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.106, train_los Epoch 38: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.107, train_los Epoch 39: 100%|█| 47/47 [00:01<00:00, 35.49batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.107, train_los Epoch 40: 100%|█| 47/47 [00:01<00:00, 34.25batch/s, test_accuracy=0.112, test_loss=16.3, train_accuracy=0.109, train_lo
Training MLP for learning rate = 0.00046415888336127773
Epoch 1: 100%|█| 47/47 [00:01<00:00, 35.19batch/s, test_accuracy=0.099, test_loss=16.6, train_accuracy=0.0983, train_lo Epoch 2: 100%|█| 47/47 [00:01<00:00, 35.87batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0948, train_lo Epoch 3: 100%|█| 47/47 [00:01<00:00, 35.33batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0918, train_lo Epoch 4: 100%|█| 47/47 [00:01<00:00, 35.57batch/s, test_accuracy=0.106, test_loss=16.5, train_accuracy=0.105, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 35.37batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.117, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 35.05batch/s, test_accuracy=0.123, test_loss=16.1, train_accuracy=0.122, train_los Epoch 7: 100%|█| 47/47 [00:01<00:00, 37.60batch/s, test_accuracy=0.136, test_loss=15.9, train_accuracy=0.122, train_los Epoch 8: 100%|█| 47/47 [00:01<00:00, 35.84batch/s, test_accuracy=0.136, test_loss=15.9, train_accuracy=0.132, train_los Epoch 9: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.14, test_loss=15.8, train_accuracy=0.142, train_loss Epoch 10: 100%|█| 47/47 [00:01<00:00, 35.64batch/s, test_accuracy=0.148, test_loss=15.7, train_accuracy=0.154, train_lo Epoch 11: 100%|█| 47/47 [00:01<00:00, 35.24batch/s, test_accuracy=0.157, test_loss=15.5, train_accuracy=0.162, train_lo Epoch 12: 100%|█| 47/47 [00:01<00:00, 35.00batch/s, test_accuracy=0.165, test_loss=15.4, train_accuracy=0.174, train_lo Epoch 13: 100%|█| 47/47 [00:01<00:00, 36.69batch/s, test_accuracy=0.165, test_loss=15.4, train_accuracy=0.176, train_lo Epoch 14: 100%|█| 47/47 [00:01<00:00, 36.05batch/s, test_accuracy=0.17, test_loss=15.3, train_accuracy=0.185, train_los Epoch 15: 100%|█| 47/47 [00:01<00:00, 35.56batch/s, test_accuracy=0.172, test_loss=15.2, train_accuracy=0.192, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 33.93batch/s, test_accuracy=0.183, test_loss=15, train_accuracy=0.199, train_loss Epoch 17: 100%|█| 47/47 [00:01<00:00, 34.94batch/s, test_accuracy=0.186, 
test_loss=15, train_accuracy=0.204, train_loss Epoch 18: 100%|█| 47/47 [00:01<00:00, 33.82batch/s, test_accuracy=0.198, test_loss=14.8, train_accuracy=0.213, train_lo Epoch 19: 100%|█| 47/47 [00:01<00:00, 35.91batch/s, test_accuracy=0.203, test_loss=14.7, train_accuracy=0.219, train_lo Epoch 20: 100%|█| 47/47 [00:01<00:00, 33.81batch/s, test_accuracy=0.207, test_loss=14.6, train_accuracy=0.227, train_lo Epoch 21: 100%|█| 47/47 [00:01<00:00, 34.96batch/s, test_accuracy=0.22, test_loss=14.4, train_accuracy=0.233, train_los Epoch 22: 100%|█| 47/47 [00:01<00:00, 36.36batch/s, test_accuracy=0.227, test_loss=14.2, train_accuracy=0.239, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 44.30batch/s, test_accuracy=0.23, test_loss=14.2, train_accuracy=0.245, train_los Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.95batch/s, test_accuracy=0.24, test_loss=14, train_accuracy=0.253, train_loss= Epoch 25: 100%|█| 47/47 [00:01<00:00, 43.90batch/s, test_accuracy=0.249, test_loss=13.8, train_accuracy=0.258, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.71batch/s, test_accuracy=0.259, test_loss=13.6, train_accuracy=0.265, train_lo Epoch 27: 100%|█| 47/47 [00:01<00:00, 46.80batch/s, test_accuracy=0.26, test_loss=13.6, train_accuracy=0.27, train_loss Epoch 28: 100%|█| 47/47 [00:01<00:00, 45.00batch/s, test_accuracy=0.268, test_loss=13.5, train_accuracy=0.276, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 44.27batch/s, test_accuracy=0.28, test_loss=13.3, train_accuracy=0.286, train_los Epoch 30: 100%|█| 47/47 [00:01<00:00, 46.29batch/s, test_accuracy=0.288, test_loss=13.1, train_accuracy=0.298, train_lo Epoch 31: 100%|█| 47/47 [00:00<00:00, 47.53batch/s, test_accuracy=0.291, test_loss=13.1, train_accuracy=0.302, train_lo Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.68batch/s, test_accuracy=0.303, test_loss=12.8, train_accuracy=0.316, train_lo Epoch 33: 100%|█| 47/47 [00:01<00:00, 45.18batch/s, test_accuracy=0.302, test_loss=12.9, train_accuracy=0.318, train_lo Epoch 34: 100%|█| 47/47 
[00:00<00:00, 48.68batch/s, test_accuracy=0.313, test_loss=12.6, train_accuracy=0.325, train_lo Epoch 35: 100%|█| 47/47 [00:00<00:00, 48.78batch/s, test_accuracy=0.328, test_loss=12.4, train_accuracy=0.333, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 43.71batch/s, test_accuracy=0.326, test_loss=12.4, train_accuracy=0.342, train_lo Epoch 37: 100%|█| 47/47 [00:01<00:00, 43.30batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.346, train_lo Epoch 38: 100%|█| 47/47 [00:01<00:00, 47.00batch/s, test_accuracy=0.343, test_loss=12.1, train_accuracy=0.354, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 46.10batch/s, test_accuracy=0.348, test_loss=12, train_accuracy=0.364, train_loss Epoch 40: 100%|█| 47/47 [00:01<00:00, 45.93batch/s, test_accuracy=0.353, test_loss=11.9, train_accuracy=0.367, train_lo
Training MLP for learning rate = 0.002154434690031882
Epoch 1: 100%|█| 47/47 [00:00<00:00, 47.78batch/s, test_accuracy=0.079, test_loss=17, train_accuracy=0.0878, train_loss Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.03batch/s, test_accuracy=0.134, test_loss=15.9, train_accuracy=0.125, train_los Epoch 3: 100%|█| 47/47 [00:01<00:00, 46.49batch/s, test_accuracy=0.157, test_loss=15.5, train_accuracy=0.148, train_los Epoch 4: 100%|█| 47/47 [00:00<00:00, 48.24batch/s, test_accuracy=0.229, test_loss=14.2, train_accuracy=0.223, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 45.99batch/s, test_accuracy=0.259, test_loss=13.6, train_accuracy=0.254, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 42.33batch/s, test_accuracy=0.285, test_loss=13.2, train_accuracy=0.287, train_los Epoch 7: 100%|█| 47/47 [00:02<00:00, 16.35batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.33, train_loss Epoch 8: 100%|█| 47/47 [00:02<00:00, 21.64batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.332, train_los Epoch 9: 100%|█| 47/47 [00:01<00:00, 29.36batch/s, test_accuracy=0.358, test_loss=11.8, train_accuracy=0.34, train_loss Epoch 10: 100%|█| 47/47 [00:01<00:00, 24.24batch/s, test_accuracy=0.405, test_loss=11, train_accuracy=0.409, train_loss Epoch 11: 100%|█| 47/47 [00:01<00:00, 28.28batch/s, test_accuracy=0.427, test_loss=10.5, train_accuracy=0.403, train_lo Epoch 12: 100%|█| 47/47 [00:01<00:00, 30.93batch/s, test_accuracy=0.478, test_loss=9.61, train_accuracy=0.448, train_lo Epoch 13: 100%|█| 47/47 [00:01<00:00, 36.17batch/s, test_accuracy=0.54, test_loss=8.47, train_accuracy=0.516, train_los Epoch 14: 100%|█| 47/47 [00:01<00:00, 35.90batch/s, test_accuracy=0.52, test_loss=8.84, train_accuracy=0.503, train_los Epoch 15: 100%|█| 47/47 [00:01<00:00, 33.76batch/s, test_accuracy=0.502, test_loss=9.17, train_accuracy=0.491, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 35.51batch/s, test_accuracy=0.571, test_loss=7.9, train_accuracy=0.558, train_los Epoch 17: 100%|█| 47/47 [00:01<00:00, 35.52batch/s, test_accuracy=0.552, 
test_loss=8.25, train_accuracy=0.542, train_lo Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.92batch/s, test_accuracy=0.574, test_loss=7.84, train_accuracy=0.555, train_lo Epoch 19: 100%|█| 47/47 [00:01<00:00, 32.98batch/s, test_accuracy=0.582, test_loss=7.7, train_accuracy=0.557, train_los Epoch 20: 100%|█| 47/47 [00:01<00:00, 34.64batch/s, test_accuracy=0.615, test_loss=7.09, train_accuracy=0.587, train_lo Epoch 21: 100%|█| 47/47 [00:01<00:00, 36.74batch/s, test_accuracy=0.619, test_loss=7.01, train_accuracy=0.587, train_lo Epoch 22: 100%|█| 47/47 [00:01<00:00, 36.93batch/s, test_accuracy=0.607, test_loss=7.24, train_accuracy=0.592, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 35.18batch/s, test_accuracy=0.664, test_loss=6.19, train_accuracy=0.625, train_lo Epoch 24: 100%|█| 47/47 [00:01<00:00, 36.72batch/s, test_accuracy=0.619, test_loss=7.01, train_accuracy=0.593, train_lo Epoch 25: 100%|█| 47/47 [00:01<00:00, 35.51batch/s, test_accuracy=0.682, test_loss=5.85, train_accuracy=0.664, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 35.74batch/s, test_accuracy=0.679, test_loss=5.91, train_accuracy=0.645, train_lo Epoch 27: 100%|█| 47/47 [00:01<00:00, 36.62batch/s, test_accuracy=0.668, test_loss=6.11, train_accuracy=0.639, train_lo Epoch 28: 100%|█| 47/47 [00:01<00:00, 33.77batch/s, test_accuracy=0.632, test_loss=6.78, train_accuracy=0.607, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 35.56batch/s, test_accuracy=0.689, test_loss=5.73, train_accuracy=0.667, train_lo Epoch 30: 100%|█| 47/47 [00:01<00:00, 35.01batch/s, test_accuracy=0.703, test_loss=5.47, train_accuracy=0.676, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 36.62batch/s, test_accuracy=0.687, test_loss=5.76, train_accuracy=0.662, train_lo Epoch 32: 100%|█| 47/47 [00:01<00:00, 34.31batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.705, train_lo Epoch 33: 100%|█| 47/47 [00:01<00:00, 35.22batch/s, test_accuracy=0.72, test_loss=5.16, train_accuracy=0.689, train_los Epoch 34: 100%|█| 47/47 
[00:01<00:00, 35.39batch/s, test_accuracy=0.727, test_loss=5.03, train_accuracy=0.695, train_lo Epoch 35: 100%|█| 47/47 [00:01<00:00, 32.02batch/s, test_accuracy=0.724, test_loss=5.08, train_accuracy=0.7, train_loss Epoch 36: 100%|█| 47/47 [00:01<00:00, 36.53batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.701, train_lo Epoch 37: 100%|█| 47/47 [00:00<00:00, 47.51batch/s, test_accuracy=0.771, test_loss=4.22, train_accuracy=0.744, train_lo Epoch 38: 100%|█| 47/47 [00:01<00:00, 40.25batch/s, test_accuracy=0.749, test_loss=4.62, train_accuracy=0.727, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 39.19batch/s, test_accuracy=0.754, test_loss=4.53, train_accuracy=0.724, train_lo Epoch 40: 100%|█| 47/47 [00:01<00:00, 43.86batch/s, test_accuracy=0.773, test_loss=4.18, train_accuracy=0.75, train_los
Training MLP for learning rate = 0.01
Epoch 1: 100%|█| 47/47 [00:00<00:00, 50.50batch/s, test_accuracy=0.194, test_loss=14.8, train_accuracy=0.193, train_los Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.89batch/s, test_accuracy=0.179, test_loss=15.1, train_accuracy=0.186, train_los Epoch 3: 100%|█| 47/47 [00:01<00:00, 38.61batch/s, test_accuracy=0.433, test_loss=10.4, train_accuracy=0.428, train_los Epoch 4: 100%|█| 47/47 [00:01<00:00, 45.37batch/s, test_accuracy=0.474, test_loss=9.68, train_accuracy=0.485, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 45.85batch/s, test_accuracy=0.546, test_loss=8.36, train_accuracy=0.536, train_los Epoch 6: 100%|█| 47/47 [00:00<00:00, 48.28batch/s, test_accuracy=0.562, test_loss=8.06, train_accuracy=0.568, train_los Epoch 7: 100%|█| 47/47 [00:00<00:00, 49.49batch/s, test_accuracy=0.559, test_loss=8.12, train_accuracy=0.538, train_los Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.27batch/s, test_accuracy=0.631, test_loss=6.79, train_accuracy=0.636, train_los Epoch 9: 100%|█| 47/47 [00:00<00:00, 48.24batch/s, test_accuracy=0.687, test_loss=5.76, train_accuracy=0.681, train_los Epoch 10: 100%|█| 47/47 [00:00<00:00, 48.81batch/s, test_accuracy=0.75, test_loss=4.6, train_accuracy=0.746, train_loss Epoch 11: 100%|█| 47/47 [00:00<00:00, 50.21batch/s, test_accuracy=0.738, test_loss=4.82, train_accuracy=0.725, train_lo Epoch 12: 100%|█| 47/47 [00:00<00:00, 48.94batch/s, test_accuracy=0.733, test_loss=4.92, train_accuracy=0.733, train_lo Epoch 13: 100%|█| 47/47 [00:00<00:00, 50.79batch/s, test_accuracy=0.748, test_loss=4.64, train_accuracy=0.748, train_lo Epoch 14: 100%|█| 47/47 [00:00<00:00, 49.54batch/s, test_accuracy=0.802, test_loss=3.65, train_accuracy=0.79, train_los Epoch 15: 100%|█| 47/47 [00:00<00:00, 48.62batch/s, test_accuracy=0.784, test_loss=3.98, train_accuracy=0.785, train_lo Epoch 16: 100%|█| 47/47 [00:00<00:00, 48.33batch/s, test_accuracy=0.804, test_loss=3.61, train_accuracy=0.786, train_lo Epoch 17: 100%|█| 47/47 [00:00<00:00, 49.12batch/s, test_accuracy=0.825, 
test_loss=3.22, train_accuracy=0.823, train_lo Epoch 18: 100%|█| 47/47 [00:00<00:00, 50.56batch/s, test_accuracy=0.788, test_loss=3.9, train_accuracy=0.791, train_los Epoch 19: 100%|█| 47/47 [00:00<00:00, 49.71batch/s, test_accuracy=0.857, test_loss=2.63, train_accuracy=0.838, train_lo Epoch 20: 100%|█| 47/47 [00:00<00:00, 49.50batch/s, test_accuracy=0.866, test_loss=2.47, train_accuracy=0.852, train_lo Epoch 21: 100%|█| 47/47 [00:00<00:00, 47.82batch/s, test_accuracy=0.809, test_loss=3.52, train_accuracy=0.805, train_lo Epoch 22: 100%|█| 47/47 [00:01<00:00, 44.21batch/s, test_accuracy=0.783, test_loss=4, train_accuracy=0.782, train_loss= Epoch 23: 100%|█| 47/47 [00:00<00:00, 48.62batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.858, train_los Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.89batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.852, train_los Epoch 25: 100%|█| 47/47 [00:00<00:00, 48.58batch/s, test_accuracy=0.872, test_loss=2.36, train_accuracy=0.867, train_lo Epoch 26: 100%|█| 47/47 [00:00<00:00, 48.40batch/s, test_accuracy=0.882, test_loss=2.17, train_accuracy=0.871, train_lo Epoch 27: 100%|█| 47/47 [00:00<00:00, 50.29batch/s, test_accuracy=0.865, test_loss=2.49, train_accuracy=0.86, train_los Epoch 28: 100%|█| 47/47 [00:00<00:00, 48.77batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.868, train_lo Epoch 29: 100%|█| 47/47 [00:00<00:00, 48.04batch/s, test_accuracy=0.86, test_loss=2.58, train_accuracy=0.861, train_los Epoch 30: 100%|█| 47/47 [00:00<00:00, 48.77batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.884, train_lo Epoch 31: 100%|█| 47/47 [00:00<00:00, 49.22batch/s, test_accuracy=0.889, test_loss=2.04, train_accuracy=0.883, train_lo Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.883, test_loss=2.15, train_accuracy=0.862, train_lo Epoch 33: 100%|█| 47/47 [00:00<00:00, 48.42batch/s, test_accuracy=0.897, test_loss=1.9, train_accuracy=0.886, train_los Epoch 34: 100%|█| 47/47 
[00:00<00:00, 51.62batch/s, test_accuracy=0.898, test_loss=1.88, train_accuracy=0.884, train_lo Epoch 35: 100%|█| 47/47 [00:00<00:00, 50.73batch/s, test_accuracy=0.898, test_loss=1.88, train_accuracy=0.89, train_los Epoch 36: 100%|█| 47/47 [00:00<00:00, 47.89batch/s, test_accuracy=0.897, test_loss=1.9, train_accuracy=0.889, train_los Epoch 37: 100%|█| 47/47 [00:00<00:00, 48.14batch/s, test_accuracy=0.895, test_loss=1.93, train_accuracy=0.882, train_lo Epoch 38: 100%|█| 47/47 [00:00<00:00, 48.27batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.88, train_los Epoch 39: 100%|█| 47/47 [00:01<00:00, 46.83batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.895, train_lo Epoch 40: 100%|█| 47/47 [00:00<00:00, 48.33batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.896, train_lo
Training MLP for learning rate = 0.046415888336127774
Epoch 1: 100%|█| 47/47 [00:00<00:00, 50.65batch/s, test_accuracy=0.152, test_loss=15.6, train_accuracy=0.151, train_los Epoch 2: 100%|█| 47/47 [00:00<00:00, 49.92batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.33, train_loss Epoch 3: 100%|█| 47/47 [00:00<00:00, 51.43batch/s, test_accuracy=0.618, test_loss=7.03, train_accuracy=0.62, train_loss Epoch 4: 100%|█| 47/47 [00:00<00:00, 49.22batch/s, test_accuracy=0.578, test_loss=7.77, train_accuracy=0.578, train_los Epoch 5: 100%|█| 47/47 [00:00<00:00, 50.35batch/s, test_accuracy=0.796, test_loss=3.76, train_accuracy=0.789, train_los Epoch 6: 100%|█| 47/47 [00:00<00:00, 49.76batch/s, test_accuracy=0.711, test_loss=5.32, train_accuracy=0.71, train_loss Epoch 7: 100%|█| 47/47 [00:00<00:00, 51.60batch/s, test_accuracy=0.791, test_loss=3.85, train_accuracy=0.8, train_loss= Epoch 8: 100%|█| 47/47 [00:00<00:00, 50.46batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.864, train_loss Epoch 9: 100%|█| 47/47 [00:00<00:00, 51.83batch/s, test_accuracy=0.881, test_loss=2.19, train_accuracy=0.877, train_los Epoch 10: 100%|█| 47/47 [00:00<00:00, 48.38batch/s, test_accuracy=0.856, test_loss=2.65, train_accuracy=0.858, train_lo Epoch 11: 100%|█| 47/47 [00:01<00:00, 39.87batch/s, test_accuracy=0.868, test_loss=2.43, train_accuracy=0.871, train_lo Epoch 12: 100%|█| 47/47 [00:01<00:00, 46.71batch/s, test_accuracy=0.886, test_loss=2.1, train_accuracy=0.875, train_los Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.94batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.903, train_lo Epoch 14: 100%|█| 47/47 [00:00<00:00, 47.69batch/s, test_accuracy=0.877, test_loss=2.26, train_accuracy=0.876, train_lo Epoch 15: 100%|█| 47/47 [00:00<00:00, 48.98batch/s, test_accuracy=0.893, test_loss=1.97, train_accuracy=0.887, train_lo Epoch 16: 100%|█| 47/47 [00:00<00:00, 48.53batch/s, test_accuracy=0.907, test_loss=1.71, train_accuracy=0.906, train_lo Epoch 17: 100%|█| 47/47 [00:00<00:00, 48.01batch/s, test_accuracy=0.882, 
test_loss=2.17, train_accuracy=0.871, train_lo Epoch 18: 100%|█| 47/47 [00:01<00:00, 44.37batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.898, train_lo Epoch 19: 100%|█| 47/47 [00:00<00:00, 48.46batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.903, train_loss Epoch 20: 100%|█| 47/47 [00:01<00:00, 44.71batch/s, test_accuracy=0.911, test_loss=1.64, train_accuracy=0.909, train_lo Epoch 21: 100%|█| 47/47 [00:00<00:00, 47.01batch/s, test_accuracy=0.856, test_loss=2.65, train_accuracy=0.865, train_lo Epoch 22: 100%|█| 47/47 [00:00<00:00, 49.48batch/s, test_accuracy=0.91, test_loss=1.66, train_accuracy=0.917, train_los Epoch 23: 100%|█| 47/47 [00:00<00:00, 47.92batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.904, train_los Epoch 24: 100%|█| 47/47 [00:00<00:00, 50.19batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.916, train_lo Epoch 25: 100%|█| 47/47 [00:00<00:00, 49.77batch/s, test_accuracy=0.906, test_loss=1.73, train_accuracy=0.908, train_lo Epoch 26: 100%|█| 47/47 [00:00<00:00, 48.60batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.919, train_lo Epoch 27: 100%|█| 47/47 [00:00<00:00, 49.87batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.905, train_loss Epoch 28: 100%|█| 47/47 [00:00<00:00, 49.64batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.917, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 45.16batch/s, test_accuracy=0.912, test_loss=1.62, train_accuracy=0.922, train_lo Epoch 30: 100%|█| 47/47 [00:00<00:00, 50.30batch/s, test_accuracy=0.907, test_loss=1.71, train_accuracy=0.919, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 43.81batch/s, test_accuracy=0.926, test_loss=1.36, train_accuracy=0.932, train_lo Epoch 32: 100%|█| 47/47 [00:02<00:00, 19.07batch/s, test_accuracy=0.904, test_loss=1.77, train_accuracy=0.91, train_los Epoch 33: 100%|█| 47/47 [00:03<00:00, 13.59batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.928, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 25.69batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.928, train_lo Epoch 35: 100%|█| 47/47 [00:01<00:00, 27.09batch/s, test_accuracy=0.917, test_loss=1.53, train_accuracy=0.933, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 28.26batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.929, train_lo Epoch 37: 100%|█| 47/47 [00:01<00:00, 37.08batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.937, train_lo Epoch 38: 100%|█| 47/47 [00:01<00:00, 35.14batch/s, test_accuracy=0.922, test_loss=1.44, train_accuracy=0.933, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 33.26batch/s, test_accuracy=0.928, test_loss=1.33, train_accuracy=0.939, train_lo Epoch 40: 100%|█| 47/47 [00:01<00:00, 26.46batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.942, train_lo
Training MLP for learning rate = 0.21544346900318823
Epoch 1: 100%|█| 47/47 [00:02<00:00, 21.76batch/s, test_accuracy=0.195, test_loss=14.8, train_accuracy=0.192, train_los Epoch 2: 100%|█| 47/47 [00:01<00:00, 27.31batch/s, test_accuracy=0.415, test_loss=10.8, train_accuracy=0.406, train_los Epoch 3: 100%|█| 47/47 [00:01<00:00, 30.72batch/s, test_accuracy=0.659, test_loss=6.28, train_accuracy=0.652, train_los Epoch 4: 100%|█| 47/47 [00:01<00:00, 31.26batch/s, test_accuracy=0.759, test_loss=4.44, train_accuracy=0.765, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 33.99batch/s, test_accuracy=0.813, test_loss=3.44, train_accuracy=0.806, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 34.58batch/s, test_accuracy=0.855, test_loss=2.67, train_accuracy=0.864, train_los Epoch 7: 100%|█| 47/47 [00:01<00:00, 34.80batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.881, train_los Epoch 8: 100%|█| 47/47 [00:01<00:00, 34.06batch/s, test_accuracy=0.855, test_loss=2.67, train_accuracy=0.853, train_los Epoch 9: 100%|█| 47/47 [00:01<00:00, 29.78batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.898, train_loss= Epoch 10: 100%|█| 47/47 [00:01<00:00, 32.49batch/s, test_accuracy=0.89, test_loss=2.03, train_accuracy=0.899, train_los Epoch 11: 100%|█| 47/47 [00:01<00:00, 32.83batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.894, train_loss Epoch 12: 100%|█| 47/47 [00:01<00:00, 35.07batch/s, test_accuracy=0.874, test_loss=2.32, train_accuracy=0.881, train_lo Epoch 13: 100%|█| 47/47 [00:01<00:00, 33.77batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.927, train_lo Epoch 14: 100%|█| 47/47 [00:01<00:00, 35.80batch/s, test_accuracy=0.93, test_loss=1.29, train_accuracy=0.934, train_los Epoch 15: 100%|█| 47/47 [00:01<00:00, 35.50batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.928, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 36.19batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.931, train_lo Epoch 17: 100%|█| 47/47 [00:01<00:00, 34.07batch/s, test_accuracy=0.933, 
test_loss=1.23, train_accuracy=0.944, train_lo Epoch 18: 100%|█| 47/47 [00:01<00:00, 35.74batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.945, train_lo Epoch 19: 100%|█| 47/47 [00:01<00:00, 35.86batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.947, train_lo Epoch 20: 100%|█| 47/47 [00:01<00:00, 35.22batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.947, train_lo Epoch 21: 100%|█| 47/47 [00:01<00:00, 33.57batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.955, train_los Epoch 22: 100%|█| 47/47 [00:01<00:00, 38.66batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.948, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 45.04batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.954, train_lo Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.61batch/s, test_accuracy=0.92, test_loss=1.47, train_accuracy=0.936, train_los Epoch 25: 100%|█| 47/47 [00:01<00:00, 39.51batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.937, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.32batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.948, train_los Epoch 27: 100%|█| 47/47 [00:01<00:00, 45.22batch/s, test_accuracy=0.946, test_loss=0.994, train_accuracy=0.964, train_l Epoch 28: 100%|█| 47/47 [00:01<00:00, 44.73batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.959, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 45.66batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.96, train_los Epoch 30: 100%|█| 47/47 [00:01<00:00, 42.64batch/s, test_accuracy=0.949, test_loss=0.939, train_accuracy=0.977, train_l Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.88batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.95, train_los Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.971, train_lo Epoch 33: 100%|█| 47/47 [00:01<00:00, 45.36batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.973, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 44.21batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.973, train_lo Epoch 35: 100%|█| 47/47 [00:01<00:00, 42.73batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.981, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 43.72batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.975, train_lo Epoch 37: 100%|█| 47/47 [00:01<00:00, 41.21batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.978, train_lo Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.30batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.968, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 44.29batch/s, test_accuracy=0.58, test_loss=7.73, train_accuracy=0.596, train_los Epoch 40: 100%|█| 47/47 [00:01<00:00, 46.87batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.979, train_loss
Training MLP for learning rate = 1.0
Epoch 1: 100%|█| 47/47 [00:00<00:00, 48.34batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.87batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 3: 100%|█| 47/47 [00:01<00:00, 44.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 4: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 5: 100%|█| 47/47 [00:00<00:00, 49.57batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 6: 100%|█| 47/47 [00:00<00:00, 48.61batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 7: 100%|█| 47/47 [00:00<00:00, 47.60batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 8: 100%|█| 47/47 [00:01<00:00, 46.72batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 9: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16 Epoch 10: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 11: 100%|█| 47/47 [00:00<00:00, 47.22batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 12: 100%|█| 47/47 [00:00<00:00, 48.88batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.05batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 14: 100%|█| 47/47 [00:00<00:00, 47.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 15: 100%|█| 47/47 [00:00<00:00, 47.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 16: 100%|█| 47/47 [00:01<00:00, 46.27batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 17: 100%|█| 47/47 [00:00<00:00, 47.17batch/s, test_accuracy=0.1, 
test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 18: 100%|█| 47/47 [00:00<00:00, 48.37batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 19: 100%|█| 47/47 [00:01<00:00, 46.61batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 20: 100%|█| 47/47 [00:01<00:00, 44.13batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 21: 100%|█| 47/47 [00:00<00:00, 49.69batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 22: 100%|█| 47/47 [00:00<00:00, 48.09batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 23: 100%|█| 47/47 [00:01<00:00, 43.99batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 24: 100%|█| 47/47 [00:01<00:00, 43.79batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 25: 100%|█| 47/47 [00:01<00:00, 46.11batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 26: 100%|█| 47/47 [00:01<00:00, 46.49batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 27: 100%|█| 47/47 [00:00<00:00, 48.56batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 28: 100%|█| 47/47 [00:01<00:00, 46.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 29: 100%|█| 47/47 [00:00<00:00, 48.07batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 30: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.04batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.35batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 33: 100%|█| 47/47 [00:01<00:00, 42.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 34: 100%|█| 47/47 
[00:01<00:00, 45.96batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 35: 100%|█| 47/47 [00:00<00:00, 47.03batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 36: 100%|█| 47/47 [00:01<00:00, 42.50batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 37: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 38: 100%|█| 47/47 [00:01<00:00, 41.04batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 39: 100%|█| 47/47 [00:01<00:00, 45.27batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1 Epoch 40: 100%|█| 47/47 [00:01<00:00, 43.53batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
# Visualise how the final (epoch-40) train/test losses vary with the learning rate.
for series, tag in ((train_losses, 'Train loss'), (test_losses, 'Test loss')):
    plt.plot(learning_rates, series, '.-', label=tag)
plt.xscale('log')  # the learning-rate grid is logarithmically spaced
plt.xlabel('Learning rate')
plt.ylabel('Final loss')
plt.title('Final Losses against Learning rates')
plt.legend()
plt.show()
In examining the relation between the losses and the learning rate, candidate learning rates from $10^{-4}$ to $10^{0}$ are taken as 7 points equally spaced on a logarithmic scale.
From the plot, the final losses at epoch 40 for both train and test data show a decreasing trend at first, meaning that a too-small learning rate ($lr$) moves towards the optimum but is not large enough to reach convergence within the training budget.
However, the loss explodes at the largest rates, meaning that too big a step makes the iterates sway unstably around the optimal point and they hardly converge to it.
Optimal learning rate should be set at a value where both the train loss is minimal in the restricted interval and right before the test loss is exploding.
Note that this doesn't imply any information about overfitting or underfitting.
# Choose the learning rate just below the blow-up region: the second-largest
# value on the grid (train loss is still decreasing there, test loss has not
# yet exploded).
best_lr = learning_rates[-2]
print("Best learning rate:", best_lr)

# Retrain the MLP from scratch at this rate for the full schedule.
epochs, batch_size = 40, 128
(best_w, best_b,
 best_train_loss, best_train_acc,
 best_test_loss, best_test_acc) = train(
    x_train, y_train, x_test, y_test, best_lr, epochs, batch_size)
Best learning rate: 0.21544346900318823
Epoch 1: 100%|█| 47/47 [00:01<00:00, 46.95batch/s, test_accuracy=0.121, test_loss=16.2, train_accuracy=0.13, train_loss Epoch 2: 100%|█| 47/47 [00:01<00:00, 43.55batch/s, test_accuracy=0.359, test_loss=11.8, train_accuracy=0.356, train_los Epoch 3: 100%|█| 47/47 [00:01<00:00, 45.47batch/s, test_accuracy=0.545, test_loss=8.38, train_accuracy=0.536, train_los Epoch 4: 100%|█| 47/47 [00:01<00:00, 46.50batch/s, test_accuracy=0.602, test_loss=7.33, train_accuracy=0.584, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 44.83batch/s, test_accuracy=0.719, test_loss=5.17, train_accuracy=0.723, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 46.84batch/s, test_accuracy=0.783, test_loss=4, train_accuracy=0.777, train_loss=4 Epoch 7: 100%|█| 47/47 [00:01<00:00, 46.47batch/s, test_accuracy=0.854, test_loss=2.69, train_accuracy=0.859, train_los Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.17batch/s, test_accuracy=0.858, test_loss=2.61, train_accuracy=0.865, train_los Epoch 9: 100%|█| 47/47 [00:01<00:00, 30.05batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.893, train_los Epoch 10: 100%|█| 47/47 [00:03<00:00, 14.71batch/s, test_accuracy=0.888, test_loss=2.06, train_accuracy=0.888, train_lo Epoch 11: 100%|█| 47/47 [00:01<00:00, 25.39batch/s, test_accuracy=0.89, test_loss=2.03, train_accuracy=0.894, train_los Epoch 12: 100%|█| 47/47 [00:01<00:00, 27.07batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.905, train_lo Epoch 13: 100%|█| 47/47 [00:01<00:00, 25.46batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.911, train_los Epoch 14: 100%|█| 47/47 [00:01<00:00, 27.23batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.916, train_lo Epoch 15: 100%|█| 47/47 [00:01<00:00, 34.49batch/s, test_accuracy=0.904, test_loss=1.77, train_accuracy=0.911, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 34.07batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.938, train_lo Epoch 17: 100%|█| 47/47 [00:01<00:00, 33.49batch/s, test_accuracy=0.92, 
test_loss=1.47, train_accuracy=0.926, train_los Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.99batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.94, train_los Epoch 19: 100%|█| 47/47 [00:01<00:00, 34.34batch/s, test_accuracy=0.909, test_loss=1.68, train_accuracy=0.913, train_lo Epoch 20: 100%|█| 47/47 [00:01<00:00, 34.63batch/s, test_accuracy=0.927, test_loss=1.34, train_accuracy=0.943, train_lo Epoch 21: 100%|█| 47/47 [00:01<00:00, 35.23batch/s, test_accuracy=0.886, test_loss=2.1, train_accuracy=0.898, train_los Epoch 22: 100%|█| 47/47 [00:01<00:00, 31.07batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.951, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 34.50batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.952, train_lo Epoch 24: 100%|█| 47/47 [00:01<00:00, 34.34batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.958, train_lo Epoch 25: 100%|█| 47/47 [00:01<00:00, 34.11batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.952, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 33.95batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.957, train_lo Epoch 27: 100%|█| 47/47 [00:01<00:00, 34.12batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.939, train_lo Epoch 28: 100%|█| 47/47 [00:01<00:00, 33.37batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.947, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 32.59batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.953, train_lo Epoch 30: 100%|█| 47/47 [00:01<00:00, 33.33batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.962, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 35.62batch/s, test_accuracy=0.926, test_loss=1.36, train_accuracy=0.96, train_los Epoch 32: 100%|█| 47/47 [00:01<00:00, 34.43batch/s, test_accuracy=0.93, test_loss=1.29, train_accuracy=0.957, train_los Epoch 33: 100%|█| 47/47 [00:01<00:00, 33.92batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.976, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 34.72batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.971, train_l Epoch 35: 100%|█| 47/47 [00:01<00:00, 34.19batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.969, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 31.55batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.977, train_lo Epoch 37: 100%|█| 47/47 [00:01<00:00, 33.98batch/s, test_accuracy=0.952, test_loss=0.884, train_accuracy=0.983, train_l Epoch 38: 100%|█| 47/47 [00:01<00:00, 35.29batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.978, train_los Epoch 39: 100%|█| 47/47 [00:01<00:00, 40.85batch/s, test_accuracy=0.949, test_loss=0.939, train_accuracy=0.985, train_l Epoch 40: 100%|█| 47/47 [00:01<00:00, 44.46batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.984, train_lo
# print best_train_loss
# Notebook echo: display the per-epoch training-loss history (one entry per
# epoch) of the run trained at the best learning rate.
best_train_loss
[16.02347237663499, 11.856510391290234, 8.54257891546377, 7.668069220454009, 5.102840781758711, 4.111729794080982, 2.6051183546080248, 2.4885170619400565, 1.9699481550746192, 2.0681387173213293, 1.9515374246533614, 1.7490193900195217, 1.6324180973515536, 1.5465013553856826, 1.6446919176323926, 1.1476021962584233, 1.3623940511731014, 1.1107807354159072, 1.5986650915792473, 1.0524800890819228, 1.8686891377576997, 0.9021257906416481, 0.8775781500799706, 0.7732506776928413, 0.8898519703608094, 0.8008667733247284, 1.1322599209073747, 0.9819056224671, 0.865304329799132, 0.6996077560078087, 0.7425661269907444, 0.7855244979736801, 0.4449259851804047, 0.5277742720760663, 0.5768695531994212, 0.4173098895485175, 0.32218778237201723, 0.4019676141974691, 0.27309250124866224, 0.30377705195075905]
# Plot the learning curves of the best-lr run — losses on the left axis,
# accuracies on the right — then report the final-epoch values.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(range(epochs), best_train_loss, label='Training Loss')
ax1.plot(range(epochs), best_test_loss, label='Test Loss')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.grid()
ax1.legend()

ax2.plot(range(epochs), best_train_acc, label='Training accuracy')
ax2.plot(range(epochs), best_test_acc, label='Test accuracy')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.grid()
ax2.legend()

plt.suptitle(f"Outputs from MLP, with best learning rate = {best_lr:.4f}")
plt.show()

### adding the print-outs of final losses
# (fixed typo: "repectively" -> "respectively"; dropped needless f-prefixes
# on the placeholder-free label strings above)
print("The final losses of training and test data are: ", best_train_loss[-1], best_test_loss[-1], "respectively.")
print("The final accuracies of training and test data are: ", best_train_acc[-1], best_test_acc[-1], "respectively.")
The final losses of training and test data are: 0.30377705195075905 1.1230545556967457 repectively. The final accuracies of training and test data are: 0.9835 0.939 repectively.
# retrain using optimal learning rate
# This time the rate is the grid value with the smallest recorded final
# training loss, and the hidden layer is shrunk to width 50.
best_lr = learning_rates[np.argmin(train_losses)]
width, epochs, batch_size = 50, 40, 128
(best_w_2, best_b_2,
 best_train_loss_2, best_train_acc_2,
 best_test_loss_2, best_test_acc_2) = train(
    x_train, y_train, x_test, y_test, best_lr, epochs,
    width=width, batch_size=batch_size)
Epoch 1: 100%|█| 47/47 [00:00<00:00, 168.91batch/s, test_accuracy=0.459, test_loss=9.96, train_accuracy=0.452, train_lo Epoch 2: 100%|█| 47/47 [00:00<00:00, 141.95batch/s, test_accuracy=0.673, test_loss=6.02, train_accuracy=0.677, train_lo Epoch 3: 100%|█| 47/47 [00:00<00:00, 143.24batch/s, test_accuracy=0.649, test_loss=6.46, train_accuracy=0.648, train_lo Epoch 4: 100%|█| 47/47 [00:00<00:00, 161.75batch/s, test_accuracy=0.852, test_loss=2.72, train_accuracy=0.864, train_lo Epoch 5: 100%|█| 47/47 [00:00<00:00, 166.55batch/s, test_accuracy=0.873, test_loss=2.34, train_accuracy=0.881, train_lo Epoch 6: 100%|█| 47/47 [00:00<00:00, 151.04batch/s, test_accuracy=0.885, test_loss=2.12, train_accuracy=0.891, train_lo Epoch 7: 100%|█| 47/47 [00:00<00:00, 161.94batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.913, train_lo Epoch 8: 100%|█| 47/47 [00:00<00:00, 162.50batch/s, test_accuracy=0.914, test_loss=1.58, train_accuracy=0.92, train_los Epoch 9: 100%|█| 47/47 [00:00<00:00, 165.59batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.915, train_lo Epoch 10: 100%|█| 47/47 [00:00<00:00, 116.65batch/s, test_accuracy=0.913, test_loss=1.6, train_accuracy=0.917, train_lo Epoch 11: 100%|█| 47/47 [00:00<00:00, 159.21batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.929, train_l Epoch 12: 100%|█| 47/47 [00:00<00:00, 170.74batch/s, test_accuracy=0.911, test_loss=1.64, train_accuracy=0.911, train_l Epoch 13: 100%|█| 47/47 [00:00<00:00, 137.39batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.935, train_l Epoch 14: 100%|█| 47/47 [00:00<00:00, 167.71batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.949, train_l Epoch 15: 100%|█| 47/47 [00:00<00:00, 168.31batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.946, train_lo Epoch 16: 100%|█| 47/47 [00:00<00:00, 170.61batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.956, train_l Epoch 17: 100%|█| 47/47 [00:00<00:00, 174.54batch/s, test_accuracy=0.924, 
test_loss=1.4, train_accuracy=0.954, train_lo Epoch 18: 100%|█| 47/47 [00:00<00:00, 165.94batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.965, train_l Epoch 19: 100%|█| 47/47 [00:00<00:00, 167.11batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.965, train_lo Epoch 20: 100%|█| 47/47 [00:00<00:00, 160.84batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.967, train_l Epoch 21: 100%|█| 47/47 [00:00<00:00, 162.17batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.926, train_l Epoch 22: 100%|█| 47/47 [00:00<00:00, 155.02batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.966, train_l Epoch 23: 100%|█| 47/47 [00:00<00:00, 163.63batch/s, test_accuracy=0.88, test_loss=2.21, train_accuracy=0.901, train_lo Epoch 24: 100%|█| 47/47 [00:00<00:00, 182.66batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.973, train_l Epoch 25: 100%|█| 47/47 [00:00<00:00, 160.29batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.979, train_l Epoch 26: 100%|█| 47/47 [00:00<00:00, 156.56batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.978, train_l Epoch 27: 100%|█| 47/47 [00:00<00:00, 171.37batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.977, train_l Epoch 28: 100%|█| 47/47 [00:00<00:00, 173.26batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.977, train_l Epoch 29: 100%|█| 47/47 [00:00<00:00, 166.36batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.984, train_l Epoch 30: 100%|█| 47/47 [00:00<00:00, 160.29batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.986, train_l Epoch 31: 100%|█| 47/47 [00:00<00:00, 159.13batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.985, train_l Epoch 32: 100%|█| 47/47 [00:00<00:00, 119.00batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.986, train_l Epoch 33: 100%|█| 47/47 [00:00<00:00, 154.51batch/s, test_accuracy=0.927, test_loss=1.34, train_accuracy=0.981, train_l Epoch 34: 100%|█| 47/47 
[00:00<00:00, 167.71batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.991, train_l Epoch 35: 100%|█| 47/47 [00:00<00:00, 157.40batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.994, train_l Epoch 36: 100%|█| 47/47 [00:00<00:00, 164.78batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.987, train_l Epoch 37: 100%|█| 47/47 [00:00<00:00, 166.25batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.994, train_l Epoch 38: 100%|█| 47/47 [00:00<00:00, 166.52batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.989, train_lo Epoch 39: 100%|█| 47/47 [00:00<00:00, 170.75batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.993, train_l Epoch 40: 100%|█| 47/47 [00:00<00:00, 152.51batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.997, train_l
# print best_train_loss_2
# Bare expression: the notebook displays the per-epoch training losses for the
# width = 50 model (one entry per epoch, 40 total).
best_train_loss_2
[10.098285636060082, 5.946665926066375, 6.47750865321265, 2.513064702501734, 2.1816715549190877, 2.000632705776716, 1.5925281814388281, 1.4697899786304403, 1.5649120858069405, 1.5188852597537956, 1.3040934048391173, 1.6446919176323926, 1.1997659324519878, 0.9450841616245838, 0.9880425326075194, 0.816209048675777, 0.8560989645885029, 0.6535809299546634, 0.6474440198142442, 0.6014171937610987, 1.3531886859624722, 0.6259648343227762, 1.8287992218449738, 0.5032266315143887, 0.3927622489868399, 0.4111729794080981, 0.432652164899566, 0.4203783446187272, 0.2915032316699203, 0.26388713603803315, 0.2792294113890816, 0.2516133157571944, 0.34673542293369475, 0.15955966365090377, 0.11966974773817782, 0.23933949547635563, 0.11046438252754877, 0.20251803463383938, 0.12887511294880688, 0.05830064633398407]
# Compare the accuracy curves of the two hidden-layer widths side by side:
# width = 200 on the left panel, width = 50 on the right panel.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
epoch_axis = range(40)
panels = [
    (axes[0], best_train_acc, best_test_acc, 200),
    (axes[1], best_train_acc_2, best_test_acc_2, 50),
]
for panel, train_curve, test_curve, width in panels:
    panel.plot(epoch_axis, train_curve, label=f'Training accuracy for width = {width}')
    panel.plot(epoch_axis, test_curve, label=f'Test accuracy for width = {width}')
    panel.set_xlabel("Epoch")
    panel.set_ylabel("Accuracy")
    panel.grid()
    panel.legend()
plt.suptitle(f"Effect of Increasing Epoches on Accuracies, with best learning rate = {best_lr:.4f}")
plt.show()
# Report the last-epoch accuracy of each model.
print("For width = 200:")
print("Final Train Accuracy: ", best_train_acc[-1])
print("Final Test Accuracy: ", best_test_acc[-1])
print("\nFor width = 50:")
print("Final Train Accuracy: ", best_train_acc_2[-1])
print("Final Test Accuracy: ", best_test_acc_2[-1])
For width = 200: Final Train Accuracy: 0.9835 Final Test Accuracy: 0.939 For width = 50: Final Train Accuracy: 0.9968333333333333 Final Test Accuracy: 0.944
# Single-panel plot of the width-50 model's training and test losses per epoch.
fig, loss_axis = plt.subplots(1, 1, figsize=(10, 5))
epoch_axis = range(40)
for curve, tag in ((best_train_loss_2, f'Training loss for width = 50'),
                   (best_test_loss_2, f'Test loss for width = 50')):
    loss_axis.plot(epoch_axis, curve, label=tag)
loss_axis.set_xlabel("Epoch")
loss_axis.set_ylabel("Loss")
loss_axis.grid()
loss_axis.legend()
plt.show()
print("The final losses of training and test data are: ", best_train_loss_2[-1], best_test_loss_2[-1], "repectively.")
The final losses of training and test data are: 0.05830064633398407 1.031000903590455 repectively.
The width of a hidden layer refers to the number of neurons or units in that layer. Similar to the model complexity for linear regression, the more neurons per layer, the more complex patterns the MLP can learn from the data.
From the print-outs, the final loss for the training and test data set are both higher for reduced-width model. This could mean that the model in 1.2.2 performs better, with better fitted test data. In addition, it can be observed that the accuracies of training data and test data are closer to each other. This indicates that model in 1.2.2 fits the unseen data better
However, for both models the accuracies on the training and test data increase with epochs, while at width = 50 the training accuracy and, in particular, the test accuracy level off. This could be because the reduced-width model has already converged.
Given more time, a tolerance can be set to return a signal once the test accuracy and test loss falls within certain criterion for both models.
# retrain MLP using optimal learning rate
# retrain using optimal learning rate
# Pick the learning rate that achieved the lowest training loss in the earlier sweep.
lr = learning_rates[np.argmin(train_losses)]
epochs = 50 # scaled epochs
batch_size = 128
dropout_prob = 0.2  # fraction of hidden units dropped on each forward pass
# NOTE(review): `train` is defined earlier in the notebook; based on the unpacking it
# returns weights, biases, and per-epoch loss/accuracy histories — confirm signature.
w_dropout,b_dropout,train_loss_dropout, train_acc_dropout,test_loss_dropout,test_acc_dropout = train(x_train, y_train, x_test, y_test, lr, epochs,batch_size,dropout_prob=dropout_prob)
Epoch 1: 100%|█| 47/47 [00:01<00:00, 45.84batch/s, test_accuracy=0.134, test_loss=15.9, train_accuracy=0.131, train_los Epoch 2: 100%|█| 47/47 [00:01<00:00, 46.44batch/s, test_accuracy=0.403, test_loss=11, train_accuracy=0.404, train_loss= Epoch 3: 100%|█| 47/47 [00:01<00:00, 45.38batch/s, test_accuracy=0.69, test_loss=5.71, train_accuracy=0.692, train_loss Epoch 4: 100%|█| 47/47 [00:01<00:00, 45.08batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.735, train_los Epoch 5: 100%|█| 47/47 [00:01<00:00, 46.80batch/s, test_accuracy=0.845, test_loss=2.85, train_accuracy=0.841, train_los Epoch 6: 100%|█| 47/47 [00:01<00:00, 46.83batch/s, test_accuracy=0.848, test_loss=2.8, train_accuracy=0.85, train_loss= Epoch 7: 100%|█| 47/47 [00:01<00:00, 46.55batch/s, test_accuracy=0.882, test_loss=2.17, train_accuracy=0.876, train_los Epoch 8: 100%|█| 47/47 [00:01<00:00, 46.68batch/s, test_accuracy=0.888, test_loss=2.06, train_accuracy=0.888, train_los Epoch 9: 100%|█| 47/47 [00:00<00:00, 48.02batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.887, train_los Epoch 10: 100%|█| 47/47 [00:01<00:00, 46.00batch/s, test_accuracy=0.885, test_loss=2.12, train_accuracy=0.894, train_lo Epoch 11: 100%|█| 47/47 [00:00<00:00, 47.89batch/s, test_accuracy=0.906, test_loss=1.73, train_accuracy=0.896, train_lo Epoch 12: 100%|█| 47/47 [00:00<00:00, 47.49batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.902, train_los Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.07batch/s, test_accuracy=0.908, test_loss=1.69, train_accuracy=0.911, train_lo Epoch 14: 100%|█| 47/47 [00:01<00:00, 44.17batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.917, train_lo Epoch 15: 100%|█| 47/47 [00:01<00:00, 46.10batch/s, test_accuracy=0.916, test_loss=1.55, train_accuracy=0.914, train_lo Epoch 16: 100%|█| 47/47 [00:01<00:00, 46.74batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.926, train_lo Epoch 17: 100%|█| 47/47 [00:01<00:00, 46.50batch/s, test_accuracy=0.915, 
test_loss=1.56, train_accuracy=0.919, train_lo Epoch 18: 100%|█| 47/47 [00:00<00:00, 47.82batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.924, train_lo Epoch 19: 100%|█| 47/47 [00:01<00:00, 45.38batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.927, train_los Epoch 20: 100%|█| 47/47 [00:01<00:00, 38.98batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.932, train_lo Epoch 21: 100%|█| 47/47 [00:01<00:00, 45.84batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.926, train_lo Epoch 22: 100%|█| 47/47 [00:01<00:00, 45.01batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.935, train_lo Epoch 23: 100%|█| 47/47 [00:01<00:00, 46.06batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.942, train_lo Epoch 24: 100%|█| 47/47 [00:01<00:00, 45.89batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.94, train_los Epoch 25: 100%|█| 47/47 [00:01<00:00, 45.56batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.939, train_lo Epoch 26: 100%|█| 47/47 [00:01<00:00, 44.18batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.94, train_los Epoch 27: 100%|█| 47/47 [00:01<00:00, 45.98batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.948, train_lo Epoch 28: 100%|█| 47/47 [00:01<00:00, 45.18batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.947, train_lo Epoch 29: 100%|█| 47/47 [00:01<00:00, 43.39batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.949, train_loss Epoch 30: 100%|█| 47/47 [00:01<00:00, 44.37batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.951, train_lo Epoch 31: 100%|█| 47/47 [00:01<00:00, 43.85batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.953, train_lo Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.51batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.953, train_lo Epoch 33: 100%|█| 47/47 [00:01<00:00, 44.97batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.955, train_lo Epoch 34: 100%|█| 47/47 
[00:01<00:00, 44.16batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.955, train_lo Epoch 35: 100%|█| 47/47 [00:01<00:00, 43.88batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.952, train_lo Epoch 36: 100%|█| 47/47 [00:01<00:00, 46.57batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.958, train_loss Epoch 37: 100%|█| 47/47 [00:01<00:00, 45.73batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.961, train_loss Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.71batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.955, train_lo Epoch 39: 100%|█| 47/47 [00:01<00:00, 45.30batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.963, train_l Epoch 40: 100%|█| 47/47 [00:01<00:00, 41.96batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.963, train_l Epoch 41: 100%|█| 47/47 [00:01<00:00, 44.53batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.961, train_lo Epoch 42: 100%|█| 47/47 [00:01<00:00, 34.88batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.964, train_lo Epoch 43: 100%|█| 47/47 [00:01<00:00, 32.89batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.965, train_lo Epoch 44: 100%|█| 47/47 [00:01<00:00, 27.37batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.965, train_lo Epoch 45: 100%|█| 47/47 [00:02<00:00, 23.36batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.967, train_lo Epoch 46: 100%|█| 47/47 [00:01<00:00, 25.63batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.97, train_los Epoch 47: 100%|█| 47/47 [00:01<00:00, 29.98batch/s, test_accuracy=0.946, test_loss=0.994, train_accuracy=0.97, train_lo Epoch 48: 100%|█| 47/47 [00:01<00:00, 31.86batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.968, train_lo Epoch 49: 100%|█| 47/47 [00:01<00:00, 34.30batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.971, train_lo Epoch 50: 100%|█| 47/47 [00:01<00:00, 33.27batch/s, test_accuracy=0.948, test_loss=0.957, 
train_accuracy=0.971, train_l
# print train_loss with drop-out rates
# Per-epoch training losses of the dropout model (one value per epoch, 50 total).
print(train_loss_dropout)
[16.001993191143523, 10.98200069628047, 5.673573424817713, 4.888048926844032, 2.9273061369800417, 2.767746473329138, 2.285999027306217, 2.0681387173213293, 2.0865494477425877, 1.9576743347937808, 1.9085790536704257, 1.8073200363535058, 1.6354865524217632, 1.5311590800346342, 1.5802543611579891, 1.3623940511731012, 1.4851322539814888, 1.3930786018751982, 1.3439833207518432, 1.2488612135753427, 1.356257141032682, 1.1874921121711492, 1.0616854542925518, 1.1046438252754875, 1.1261230107669553, 1.101575370205278, 0.9573579819054225, 0.9849740775373096, 0.9420157065543742, 0.9113311558522773, 0.865304329799132, 0.8622358747289223, 0.8376882341672447, 0.834619779097035, 0.892920425431019, 0.7701822226226317, 0.718018486429067, 0.8254144138864061, 0.68733393572697, 0.6904023907971797, 0.7241553965694862, 0.6689232053057119, 0.6413071096738248, 0.6535809299546635, 0.6044856488313084, 0.5492534575675341, 0.5553903677079536, 0.5922118285504697, 0.5431165474271147, 0.5277742720760663]
# Plot the dropout model's training curves: losses (left) and accuracies (right).
fig,(ax1,ax2) = plt.subplots(1,2, figsize=(15,5))
ax1.plot(range(epochs),train_loss_dropout,label=f'Training Loss')
ax1.plot(range(epochs),test_loss_dropout,label=f'Test Loss')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.grid()
ax2.plot(range(epochs),train_acc_dropout,label=f'Training accuracy')
ax2.plot(range(epochs),test_acc_dropout,label=f'Test accuracy')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.grid()
ax1.legend()
ax2.legend()
# BUG FIX: the format spec was "{lr: 4f}" — space sign-option, width 4, *default*
# 6 decimal places. Use ".4f" (4 decimal places), consistent with the earlier
# suptitle that formats best_lr as {best_lr:.4f}.
plt.suptitle(f"Outputs from MLP, with dropout rate = {dropout_prob} and learning rate = {lr:.4f}")
plt.show()
print("Final Dropout Train Loss: ", train_loss_dropout[-1])
print("Final Dropout Test Loss: ", test_loss_dropout[-1])
print("Final Dropout Train Accuracy: ", train_acc_dropout[-1])
print("Final Dropout Test Accuracy: ", test_acc_dropout[-1])
Final Dropout Train Loss: 0.5277742720760663 Final Dropout Test Loss: 0.9573579819054225 Final without Dropout Train Accuracy: 0.9713333333333334 Final without Dropout Test Accuracy: 0.948
# Push one test batch through the trained network with and without dropout, then
# compare the first-hidden-layer activation distributions of the two passes.
x_batch_test = x_test[:128]
# no dropout
outputs_nd = forward_prop(x_batch_test, best_w, best_b)
# with dropout = 0.2
outputs_d = forward_prop(x_batch_test, best_w, best_b, dropout_prob = 0.2)
# flatten the first hidden layer's activations ('A1') into 1-D vectors
activations_nd = outputs_nd[1]['A1'].flatten()
activations_d = outputs_d[1]['A1'].flatten()
for acts, tag in ((activations_nd, "Without Dropout"),
                  (activations_d, "With Dropout Rate = 0.2")):
    plt.hist(acts, bins=100, density=False, alpha=0.5, label=tag)
plt.legend()
plt.show()
print("Sum of activations in MLP without dropout: ", np.sum(activations_nd))
print("Sum of activations in MLP with dropout rate = 0.2: ", np.sum(activations_d))
# fraction of activations that dropout zeroed out (threshold guards float noise)
zero_fraction = len(activations_d[activations_d<1e-10])/len(activations_d)
print("\nPercentage of zero in activations with dropout:", zero_fraction*100,'%')
Sum of activations in MLP without dropout: 20555.677530687786 Sum of activations in MLP with dropout rate = 0.2: 20618.161134698887 Percentage of zero in activations with dropout: 19.78125 %
Dropout is a regularization technique to prevent overfitting. It randomly replaces elements of each column of $W$ with zeros with probability dropout_prob. By doing this, the capacity of the model is reduced and neurons are prevented from co-adapting with each other too much. Neurons in the network have less dependence on other specific neurons being present, and so each neuron learns features that are more robust and generalises better.
It can be observed from the plots without doubt that the model performs better with dropout:
$\cdot$ less fluctuating losses and accuracies,
$\cdot$ much narrower gaps for training and test data from smaller epochs onwards and,
$\cdot$ higher test accuracies(from $\boldsymbol {0.939}$ to $\boldsymbol {0.948}$).
They indicate that the model reduces disruption of noise and could have converged, and the model generalizes more robustly to unseen data.
References: http://proceedings.mlr.press/v48/gal16.pdf http://proceedings.mlr.press/v31/damianou13a.pdf
The histogram with dropout is more Gaussian distributed than that without dropout. This can be interpreted in terms of Deep Gaussian process:
A deep Gaussian process is a probabilistic substitute for an MLP. In the scenario of this part, the MLP with one hidden layer is equivalent to a standard Gaussian process (from refs).
The approach to Gaussian process modelling is to place a prior directly over a class of functions and integrate them out. In an MLP, dropping out can be viewed as Bayesian model averaging, which routes inputs through subsets of neurons corresponding to the non-zero assignments of a Bernoulli mask. By doing so, dropout reduces the co-adaptation of the network on individual neurons, making it learn more distributed representations. This is analogous to the way a Gaussian process places a prior over functions, encouraging them to be smooth and avoiding overfitting to noisy data.
Therefore, using dropout, the first hidden layer activations is more like a Gaussian process, with smooth and well-behaved functions that generalize well to new data. This corresponds to the behaviour of the orange histogram (without considerations of dropped zeros). And a better generalization corresponds to the trend shown in the loss and accuracy plots.
For various applications of NMF and PCA, denoising is applied in this section.
For given image matrices, PCA is implemented to see the trend of variance explained by principal components (eigenvectors of the covariance matrix), and for both methods the first 10 basis components are visualized and their differences are explained.
Due to the different natures of the decomposition approaches, the denoising effects at 100 principal components are discussed for both methods via reconstructed images. One important thing to notice is that, during training, the noisy images were normalized by first being divided by 255 and then by the standard procedure of standardization. Therefore, to compare reconstructed images with the original noisy and noise-free images, reconstructed images should be reverted by multiplying by 'sigma' and adding 'mu', stored earlier. After this step, all images are on the same scale.
Finally, MSE is used as a measure of the performance of PCA, to see how it behaves with respect to the noisy test data and the noise-free data. Similarly, the reconstructed images should be reverted to their original scale by the same steps mentioned above.
# read txt file
# Load the noisy MNIST images from whitespace-delimited text files.
MNIST_train_noisy = np.loadtxt('MNIST_train_noisy.txt')
MNIST_test_noisy = np.loadtxt('MNIST_test_noisy.txt')
# inspecting the data
print("The shapes of the training and test data are: ", MNIST_train_noisy.shape, MNIST_test_noisy.shape)
# BUG FIX: the values were printed in the order (test, train) under a (train, test) label.
print("The type of the MNIST_train_noisy and MNIST_test_noisy are: ", type(MNIST_train_noisy), type(MNIST_test_noisy))
# BUG FIX: the label said "max and min" but np.min is printed first, then np.max.
print("The min and max of the training noisy data: ", np.min(MNIST_train_noisy), np.max(MNIST_train_noisy))
The shapes of the training and test data are: (6000, 784) (1000, 784) The type of the MNIST_train_noisy and MNIST_test_noisy are: <class 'numpy.ndarray'> <class 'numpy.ndarray'> The max and min of the trainning noisy data: -107.31 343.8
MNIST_train_noisy[0] # intergers with Gaussian noise
array([ 1.46928e+01, 4.15443e+01, 6.50907e-01, 1.84006e+00,
-1.37111e+01, 2.70121e+00, 6.83062e+00, 2.09018e+00,
1.32685e+01, -1.81447e+01, 7.60745e+00, -4.30438e+01,
-8.82577e+00, 1.29603e+01, -1.42484e+01, -2.25878e+01,
-3.95282e+01, -3.76398e+00, -4.38008e+00, 1.06019e+01,
2.69545e+01, 1.72741e+00, -2.04378e+01, 7.15860e+00,
1.14836e+01, 8.48694e+00, 2.12904e+01, 3.88620e+00,
-2.23335e+01, -1.54388e+01, -3.57090e+01, -5.21508e+00,
-1.84571e+01, -1.35396e+01, 2.37871e+01, 1.47583e+01,
-1.28852e+01, -6.30926e+00, -8.74880e+00, -1.16587e+01,
7.42091e-01, 6.38451e+00, -1.37609e+01, -1.92919e+01,
2.79193e+01, 6.01952e+00, -7.07875e+00, 1.47998e+00,
8.53116e+00, -1.10761e+01, -1.63237e+01, 1.98356e+01,
4.32072e+00, 4.94797e+01, 1.72346e+01, -1.43560e+01,
9.83602e+00, -9.93811e+00, -2.12885e+01, -6.20323e+00,
-2.99198e+01, 1.86351e+01, 9.14288e+00, 2.57839e+01,
2.11070e+01, 3.32087e+01, 1.57824e+00, -1.72484e+01,
-1.39058e+01, -2.15084e+01, 5.44566e+01, -1.97676e+00,
3.61940e+00, 1.54496e+01, 1.26325e+01, -2.08133e+01,
-9.79608e+00, 4.42669e+01, 1.40790e+01, 1.54795e+01,
-7.34956e+00, 1.70580e+01, -2.84354e+01, -2.64093e+01,
-1.03846e+01, -1.28380e+01, -2.95061e+01, -1.24579e+00,
1.59867e+01, 9.92727e+00, -3.32957e-01, -2.42489e+01,
-1.33814e+01, 7.15406e+00, 7.60305e+00, -1.74021e+01,
-2.56823e+01, -2.96517e+01, 1.03331e+01, 1.31809e+01,
6.63456e+00, -9.38165e+00, -3.31801e+01, -3.04959e-01,
5.39579e+01, -6.06997e+00, -3.28984e+01, 3.67116e+00,
-1.18050e+00, 4.47812e+00, 3.83445e+01, 6.34854e+00,
-1.45135e+01, -5.61027e+00, 1.48787e+01, 2.09539e+01,
4.73510e-01, -1.81998e+01, -2.59578e+01, -1.35279e+01,
-7.65255e+00, 1.41188e+01, 1.87509e+00, 2.72302e+01,
-1.04744e+01, 1.91733e+01, 2.74841e+01, -2.85886e+01,
-1.35327e+01, -1.90314e+01, 1.18274e+01, 2.56375e+01,
1.93300e+01, 1.84234e+01, 7.69585e+00, -4.21003e+01,
1.64824e+01, -7.53810e+00, 1.66392e+01, -7.31290e+00,
-7.70799e+00, -3.15550e+00, 1.93177e+01, -2.63962e+01,
1.96105e+00, -1.28918e+01, 1.60156e+01, -1.19533e+01,
1.80504e+01, 5.75270e+00, -9.36589e+00, 8.39003e+00,
-3.04006e+01, -2.74955e+01, 2.28516e+01, 2.19562e+02,
1.97703e+02, 6.60478e+00, -9.21046e+00, -8.67712e+00,
7.40744e+00, 3.67294e+00, -1.69200e+01, -3.03753e+00,
-1.64380e+01, 1.94649e+01, -4.17285e+01, 1.83169e+01,
1.21473e+01, 1.69646e+01, -4.00270e+00, -2.88538e+00,
5.05098e+00, -1.12919e+01, -4.43433e+01, -3.31128e+01,
-2.34549e+01, 2.34587e+01, -9.08712e-01, -2.25376e+01,
-3.62749e+01, 2.39294e+01, 1.25245e+02, 2.55935e+02,
1.89210e+02, 3.83215e+01, -3.21856e+01, -1.65434e+01,
-2.95512e+01, -2.49321e+01, 2.25440e+01, 2.31692e+01,
1.36295e+01, -1.04602e+00, 5.22427e+00, -1.37121e+00,
-7.52505e+00, 1.18560e+01, -7.37543e+00, 2.58198e+01,
1.03998e+01, -2.96805e+01, -2.07470e+00, 9.65127e+00,
-2.05366e+01, -8.97081e+00, 2.15243e+01, -9.38292e+00,
1.35506e+01, 6.92266e+00, 1.62373e+02, 2.70264e+02,
2.45413e+02, 4.78852e+01, 2.77413e+01, 1.34091e+01,
-8.51889e+00, -2.64588e+01, -5.93605e+00, 2.32397e+01,
2.16447e+01, 2.99154e+01, 2.39283e-01, 9.20799e+00,
-4.50518e+00, 2.87857e+00, -1.30913e+01, 3.20019e+01,
4.33217e+01, -8.93203e+00, 2.51179e+01, 4.06348e+00,
5.54683e+00, -8.03559e+00, -1.85558e+00, -2.13523e+01,
2.53617e+01, -4.36704e+01, 1.88918e+02, 3.00853e+02,
2.32465e+02, 1.40830e+02, 9.86603e+00, -3.40927e+01,
-8.55709e+00, 8.53139e+00, 2.64744e+00, 1.82964e+01,
-1.90327e+01, 3.55915e+01, 1.77847e+01, 1.47957e+01,
6.43721e+00, 2.03183e+00, 1.01528e+01, 5.51030e+00,
-9.87462e+00, 9.83367e+00, 1.12899e+01, 1.90928e+01,
9.20189e+00, 2.73134e+01, 3.64131e+01, 2.11858e+01,
-7.10387e+00, 3.61107e+01, 2.45559e+02, 2.87364e+02,
2.39861e+02, 1.77156e+02, 2.87905e+01, 1.07451e+01,
2.13438e+01, -1.34293e+01, 2.14320e+01, 2.61130e+01,
-3.71112e+01, -1.03756e+01, -7.07590e-01, 2.06512e+01,
8.28111e+00, -3.65680e+01, -1.18821e+01, 4.34086e+00,
-2.44816e+01, 3.04381e+00, -6.55714e+00, -1.43526e+01,
-6.57490e+00, 1.49190e+00, -3.34429e+01, -1.75223e+01,
1.19131e+01, 2.08288e+02, 2.79990e+02, 2.42012e+02,
2.17104e+02, 1.96761e+02, 1.43511e+01, 2.13817e+01,
-2.17823e+01, -2.00719e+01, 7.89365e-01, 6.57631e+00,
-4.68865e+00, -7.88515e+00, -3.64777e+01, -4.27916e+01,
-1.23525e+01, 4.56614e+00, 3.53103e-01, 1.16850e+01,
1.72275e+01, 1.31295e+01, 2.38088e+01, -2.48391e+01,
3.24215e+01, -2.36893e+01, 1.61307e+01, -2.42744e+01,
6.36573e+01, 2.18915e+02, 1.50615e+02, 1.40312e+02,
2.70070e+02, 1.79195e+02, -4.44998e+00, -2.29097e+01,
-1.55775e+01, 2.20191e+00, -2.21838e+01, 1.84893e+00,
-1.84677e+01, -2.17944e+00, 2.30816e+01, 3.12202e+01,
-2.32265e+01, 7.87255e+00, 1.69636e+01, -2.93580e+01,
-4.90506e+00, -3.79856e+01, 2.45714e+01, -6.42660e+00,
-2.06966e+01, 1.01039e+01, -1.17117e+00, 3.78489e+01,
2.60123e+02, 2.72533e+02, 1.40536e+02, 7.33374e+01,
2.58281e+02, 1.64409e+02, 3.76308e+01, 4.11274e+00,
1.10785e+01, 2.19105e+01, 1.34282e+01, -3.59380e+00,
1.01468e+01, -2.86313e+01, 8.58247e+00, 7.98355e-01,
-5.65881e+00, -1.65159e+01, -2.48610e+01, 2.16931e+01,
1.19902e+01, -3.43677e+01, 6.95947e-01, 1.39839e+01,
2.26090e+01, 1.36541e+01, 2.45212e+01, 2.14205e+02,
2.44182e+02, 2.15888e+02, 2.22461e+01, 1.13725e+02,
2.50662e+02, 1.74208e+02, -7.26047e+00, 6.55291e+00,
3.30733e+00, 3.06851e+00, 2.28504e+01, 5.80103e+00,
-3.66832e+01, 4.71170e+00, 1.17757e+01, 1.62027e+01,
-1.55880e+00, 1.76639e+01, 1.93807e+01, -7.50557e+00,
-1.25410e+01, -3.97182e-01, -1.90590e+00, -1.76868e+01,
-1.36247e+01, -1.73211e+00, 3.85518e+01, 3.01332e+02,
2.73972e+02, 7.70106e+01, -1.14361e+01, 1.86369e+02,
2.46977e+02, 8.31455e+01, 5.65317e+00, 1.27294e+01,
1.00203e+01, -1.17402e+01, 2.77662e+01, -4.27329e+01,
-3.56599e+01, -5.54958e+01, -1.32558e+01, -2.81824e+01,
1.74667e+01, -7.21089e+00, -3.31630e+01, 1.59204e+01,
3.75129e+00, 2.59298e+00, 1.26323e+01, 9.97398e+00,
3.53983e+01, 6.54483e+01, 2.17958e+02, 2.49798e+02,
1.95730e+02, 9.34401e+01, 1.19577e+02, 2.95815e+02,
2.35680e+02, 3.95958e+01, 2.04519e+01, 2.23056e+01,
1.04728e+01, 4.13824e+01, -7.01298e+00, -1.56546e+01,
-1.09702e-01, -1.72753e-01, 3.00668e+01, 9.38009e+00,
-9.37427e+00, -3.48430e+00, -1.08003e-01, 1.65164e+01,
-2.92456e+01, 4.36536e+01, 1.77725e+01, -8.17721e+00,
9.64735e+01, 1.95233e+02, 2.55611e+02, 2.49306e+02,
2.35548e+02, 2.35231e+02, 2.49989e+02, 2.76956e+02,
2.74830e+02, 2.48525e+02, 2.06920e+01, -1.45384e+01,
-2.18008e+01, 9.40981e+00, -1.21378e+01, 6.55521e+00,
1.50937e+01, -9.53465e+00, 1.30961e+01, -1.44901e+01,
3.25047e+01, 3.11608e+01, -7.52132e+00, 3.32673e-01,
3.02300e+01, -1.09056e+01, 4.43964e+01, 4.05151e+01,
2.66073e+02, 2.51470e+02, 2.78116e+02, 2.46771e+02,
2.49153e+02, 2.28313e+02, 1.75841e+02, 2.50854e+02,
2.41130e+02, 1.76525e+02, 8.21311e+01, -1.76462e+01,
9.79665e-01, -5.03212e+01, -3.38977e+01, -3.71078e+01,
-3.43196e+01, -1.11945e+01, -4.05062e+01, 1.04211e+01,
2.35890e+01, 1.46437e+00, 8.63347e+00, -3.71576e-01,
-8.21729e+00, -1.24452e+01, 2.97835e+01, 9.99399e+01,
2.26863e+02, 1.92009e+02, 1.56521e+02, 4.45907e+01,
2.20701e+01, -1.00044e+01, 1.46868e+01, 2.55228e+02,
2.17912e+02, -1.88956e+01, -2.26498e+01, 6.99782e+00,
4.11815e+00, -2.00987e+01, 7.12709e+00, 1.81945e+00,
-2.72156e+00, 9.83292e+00, 2.06573e+00, 7.48773e+00,
4.96712e+00, -1.87271e+01, 4.35234e+00, -1.12788e+01,
3.84628e+01, -1.40378e+01, 7.41670e-01, -1.06709e+01,
5.39856e+00, 8.87864e+00, -5.78087e+00, 3.24254e+01,
3.37935e+00, 2.86719e+00, 4.89003e+01, 2.56846e+02,
2.53318e+02, -2.20285e+00, 4.06285e+01, 1.69746e+00,
-6.53814e+00, 1.45286e+01, 4.93218e+00, 2.58272e+01,
1.44117e+01, -3.02833e+01, 2.17766e+00, 1.12453e+01,
2.45699e+01, -5.09358e+00, 3.80909e+01, 2.44727e+01,
-5.77559e+00, -7.32648e+00, -3.51943e+01, -2.95459e+01,
-1.25109e+00, -2.91962e+01, 2.20334e+01, 9.05947e+00,
1.36272e+01, -4.89239e+00, 7.97818e+01, 2.45809e+02,
2.73546e+02, 6.35912e+01, 1.20074e+01, 2.35910e+00,
-7.86053e-01, 3.78986e+01, 1.37321e+01, 2.14848e+01,
-1.61202e+01, 2.09294e+01, 9.21594e+00, -8.41174e+00,
-2.13135e+01, -9.36809e+00, 7.07539e+00, -1.29072e+01,
1.97590e+01, 9.11854e+00, -1.73637e+00, 2.76084e+01,
-2.90821e+01, -4.88655e+00, 8.15024e-01, 4.01875e+00,
-2.78145e+01, 2.27606e+00, 3.85835e+01, 2.55895e+02,
2.46100e+02, 5.92354e+01, 1.00394e+01, -2.47892e+01,
-1.09727e+01, -3.66841e+01, 2.02300e+01, -5.14658e+00,
6.00640e+00, 2.49677e+01, 5.55689e+00, 4.93382e+01,
-2.31855e+01, -4.19057e+01, -8.98082e+00, -7.06535e+00,
1.63045e+01, 1.78965e+00, 2.69667e+01, -3.55541e+00,
1.28437e+01, 1.73816e+01, 3.39852e+00, 7.77912e+00,
-1.55787e+01, -4.74119e+01, 2.21011e+01, 2.06226e+02,
2.67791e+02, 5.78124e+01, 1.16466e+01, -7.92537e+00,
3.87494e+01, -9.24285e+00, -2.64023e+01, 8.46577e+00,
2.46033e+01, 5.37295e+00, -2.23286e+01, -3.15278e+01,
2.94299e+01, -1.86781e+01, 2.72944e+01, -1.43535e+01,
-1.05593e+01, -3.29578e+01, -1.22854e+01, 8.07984e+00,
3.90399e+00, 2.47789e+01, 1.54150e+01, 3.86269e+00,
1.00901e+01, -2.09984e+01, 2.81658e+01, 2.38916e+02,
2.53203e+02, 9.71723e+01, -1.40918e+01, 6.19697e+00,
2.52356e+01, 6.32046e+00, 3.43176e+00, -1.40803e+01,
-4.50321e+00, -2.11911e+01, 2.42583e+01, 3.18574e+01,
3.24158e+01, -1.76879e+01, 5.04080e+01, 3.46684e+01,
-9.57479e+00, 5.87925e+00, -1.27446e+01, -2.80142e+01,
9.41069e+00, -1.63209e+00, -1.22236e+01, 2.76099e+00,
2.37023e+01, 2.34141e+01, -1.38385e+01, 1.40432e+02,
2.36219e+02, 3.01315e+01, -5.26530e+00, 4.46779e+00,
-2.16320e+01, -1.47344e+00, 1.34955e+01, -2.41033e+01,
1.36706e+01, 5.83552e-01, -1.54566e+00, 6.31859e-01,
6.64487e+00, -1.72672e+01, -1.65740e+00, 1.20156e+01,
5.30149e+00, 4.95793e+00, 1.43906e+01, 2.97536e+01,
-2.87547e+01, 1.10865e+01, -3.50675e+01, -2.82799e-01,
1.72594e+01, 7.98271e-01, -1.63411e+01, 3.80120e+00,
7.58438e+00, 1.45247e+01, -3.70579e+00, 5.45871e+00,
3.97541e+00, 2.03193e+01, -1.78142e+01, 3.45178e+01,
-2.15872e+01, 1.39039e+01, -2.05152e+01, 3.00043e+01,
4.30353e-01, -1.14632e+01, 2.76469e+00, 1.93526e+01,
9.97542e+00, -2.63282e+01, -1.77778e+01, 7.07232e+00,
3.44905e+01, -6.07877e+00, 3.30198e+01, 5.51341e-01,
-3.05719e+01, -1.05455e+01, -2.89520e+01, 5.01119e+00,
1.15149e+01, 4.65715e+01, -3.95418e+01, -2.13420e+01,
-2.35853e+00, 2.30258e-01, -2.91287e+01, -8.13015e+00,
7.83440e+00, 5.71136e+00, 7.19097e+00, -6.76401e+00,
-2.70748e+01, 9.21056e+00, -1.35323e+01, -4.13871e+01,
2.44157e+01, -2.49330e+01, 1.33811e+01, 8.26138e+00,
-7.70961e+00, -2.92449e+00, 1.43662e+01, 3.78538e+00,
-1.25971e+01, 1.95713e+01, -6.30019e+00, -7.03768e+01,
2.76641e+01, 2.68981e+01, 2.58495e-01, 2.87334e+00,
-2.50393e+01, 1.99615e+00, 1.61607e+01, -5.71961e+00,
5.05598e+01, 2.13238e+01, -1.22810e+01, -1.42812e+01])
# plot the first image
plt.figure(figsize=(4,4))
plt.imshow(MNIST_train_noisy[0].reshape(28,28), cmap='gray');
# reshape the data
# NOTE(review): loadtxt already yielded shape (N, 784) (see the printed shapes above),
# so these reshapes are no-ops kept for explicitness.
train_noisy = MNIST_train_noisy.reshape(-1, 28*28)
test_noisy = MNIST_test_noisy.reshape(-1, 28*28)
# reshape the original NMIST data sets
# NOTE(review): column 0 is dropped — presumably the label column. Slicing with
# [:, 1:] assumes MNIST_train/MNIST_test are numpy arrays at this point
# (converted from the loaded DataFrames earlier) — TODO confirm.
train_original = MNIST_train[:,1:].reshape(-1, 28*28)
test_original = MNIST_test[:, 1:].reshape(-1, 28*28)
# define a function to normalize the data
def normalize(X, mu=None, std=None, return_stats=False, revert=False):
    """Standardize image data (scaled to [0, 1] first), or undo that standardization.

    Parameters
    ----------
    X : np.ndarray
        Image matrix. In the forward direction it is raw pixel data and is divided
        by 255 before standardizing; with ``revert=True`` it is already-standardized
        data.
    mu, std : np.ndarray, optional
        Per-feature statistics. If omitted, they are computed from ``X`` (with any
        zero std replaced by 1 to avoid division by zero).
    return_stats : bool
        If True, return ``(mu, std)`` instead of the transformed data.
    revert : bool
        If True, return ``X * std + mu``. Note this reverts to the /255 scale,
        not the raw 0-255 pixel scale.
    """
    if revert:
        return X * std + mu
    X = X / 255.
    if mu is None or std is None:
        mu = np.mean(X, axis=0)
        std = np.std(X, axis=0)
        # guard: constant features have std == 0 and would divide by zero
        std_filled = std.copy()
        std_filled[std == 0] = 1.
        std = std_filled
    # BUG FIX: previously, return_stats=True combined with caller-supplied mu/std
    # raised NameError (std_filled was only defined when the stats were computed
    # here). Now the supplied stats are simply echoed back.
    if return_stats:
        return mu, std
    return (X - mu) / std
Standardization:
# Standardize the noisy data with its own training statistics, and the clean
# (original) data with its own training statistics.
mu_pca,std_pca = normalize(train_noisy, return_stats=True)
# recomputes the same training statistics internally
train_noisy_pca = normalize(train_noisy)
# the test set is standardized with the *training* statistics
test_noisy_pca = normalize(test_noisy, mu_pca, std_pca)
mu_original, std_original = normalize(train_original, return_stats=True)
train_original_pca = normalize(train_original)
test_original_pca = normalize(test_original, mu=mu_original, std=std_original, revert=False)
print("The shapes of the reshaped training noisy data set is: ", train_noisy_pca.shape)
print("The shape of the reshaped and training original data set is: ", train_original_pca.shape)
The shapes of the reshaped training noisy data set is: (6000, 784) The shape of the reshaped and training original data set is: (6000, 784)
# perform PCA (from coding books)
def pca_function(X, m):
    """Project X onto its top-m principal components.

    Parameters
    ----------
    X : np.ndarray, shape (n_samples, n_features)
        Standardized data matrix (assumed centered).
    m : int
        Number of principal components to keep.

    Returns
    -------
    X_pca : np.ndarray, shape (n_samples, m)
        Data projected onto the principal axes.
    eigenvectors : np.ndarray, shape (n_features, m)
        Columns are principal axes, ordered by decreasing eigenvalue.
    eigenvalues : np.ndarray, shape (m,)
        Covariance-matrix eigenvalues, largest first.
    """
    # sample covariance matrix
    cov = np.dot(X.T, X) / (len(X) - 1.0)
    n_features = X.shape[1]
    if m < n_features:
        # sparse solver: only the m largest-magnitude eigenpairs are needed
        evals, evecs = scipy.sparse.linalg.eigsh(cov, m, which="LM", return_eigenvectors=True)
    else:
        # dense solver: full spectrum
        evals, evecs = scipy.linalg.eigh(cov)
    # reorder eigenpairs from largest to smallest eigenvalue;
    # column evecs[:, i] is the eigenvector for evals[i]
    order = np.argsort(evals)[::-1]
    evals = evals[order]
    evecs = evecs[:, order]
    return X.dot(evecs), evecs, evals
m = 784 # the row number of the data
X_pca, eigenvectors, eigenvalues = pca_function(train_noisy_pca, m)
# fraction of the total variance carried by each principal component
# FIX: hoist the loop-invariant total out of the comprehension — previously
# sum(eigenvalues) was recomputed for every element (O(n^2) overall).
total_variance = sum(eigenvalues)
var_explained = [evalue/total_variance for evalue in eigenvalues]
# varaicne of pc as m increases
# for multiple pc, the variance explained is the sum of variance over total variance
m_var_explained = np.cumsum(var_explained)
# plot the variance explained against m
plt.figure(figsize=(10,6))
plt.plot(m_var_explained, color='blue', label='Variance Explained')
plt.xlabel('Number of Principle Components')
plt.ylabel('Variance Explained')
plt.title('Variance Explained by Principle Components')
# plot horizontal lines at which the value of variane explained reached 0.7, 0.8 and 0.9
plt.axhline(y=0.7, color='red', linestyle='--', label='70% exlained variance')
plt.axhline(y=0.8, color='green', linestyle='--', label='80% exlained variance')
plt.axhline(y=0.9, color='orange', linestyle='--', label='90% exlained variance')
# print the first value of m when the variance explained is 0.7, 0.8 and 0.9
print("The first value of m when the variance explained is 0.7 is:", np.where(m_var_explained>0.7)[0][0])
print("The first value of m when the variance explained is 0.8 is:", np.where(m_var_explained>0.8)[0][0])
print("The first value of m when the variance explained is 0.9 is:", np.where(m_var_explained>0.9)[0][0])
plt.legend(loc='best', shadow=True, fontsize='x-large')
plt.grid()
plt.show()
The first value of m when the variance explained is 0.7 is: 212 The first value of m when the variance explained is 0.8 is: 297 The first value of m when the variance explained is 0.9 is: 407
# visualize the first 10 principle components
# FIX: the loop variable previously rebound `ax` (the axes array) to each panel,
# shadowing the array mid-iteration; use distinct names for array and panel.
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, panel in enumerate(axes.reshape(-1)):
    panel.imshow(eigenvectors[:, i].reshape(28, 28))
    panel.set_ylabel(f"PC{i+1}")
plt.tight_layout();
# Scale the raw noisy pixels by 255 before NMF.
# NOTE(review): values can still be negative here because of the Gaussian noise
# (the printed min above is -107.31) — confirm this is intended for NMF input.
train_noisy_nmf = (MNIST_train_noisy.reshape(-1, 28 * 28)) / 255.
test_noisy_nmf = (MNIST_test_noisy.reshape(-1, 28 * 28)) / 255.
# max_min normalization
def normalize_nmf(X, min=None, max=None, return_stats = False, revert=False):
    """Min-max scale X into [0, 1], report its min/max, or undo the scaling.

    revert=True maps scaled data back via X * (max - min) + min.
    return_stats=True returns (X.min(), X.max()) without transforming.
    Otherwise return (X - lo) / (hi - lo) using X's own extrema.
    """
    if revert:
        return X * (max - min) + min
    if return_stats:
        return X.min(), X.max()
    lo = X.min()
    hi = X.max()
    return (X - lo) / (hi - lo)
# normalization
# NOTE(review): the transform is applied to the *train* set, while the stored
# stats come from the *test* set — presumably kept for reverting reconstructions
# later; verify this asymmetry is intended.
train_noisy_nmf = normalize_nmf(train_noisy_nmf)
min_nmf, max_nmf = normalize_nmf(test_noisy_nmf, return_stats=True)
# define chi2 cost, same as the notebook
def cost(X, W, H):
    """Return the X-weighted chi2 cost of the factorization X ~ W @ H,
    averaged over all matrix entries: sum(X * (X - WH)^2) / (rows * cols)."""
    residual = X - W @ H
    weighted_sq_error = (X * residual) * residual
    return weighted_sq_error.sum() / (X.shape[0] * X.shape[1])
# Implement NMF
# construct placeholder matrices
np.random.seed(0)
m = 10
# m x k components matrix, usually interpreted as the coefficients
W = np.random.rand(train_noisy_nmf.shape[0], m)
# k x n matrix interpreted as the basis set(e.g. pixels)
H = np.random.rand(m, train_noisy_nmf.shape[1])
chi2 = []  # chi2 cost recorded after every iteration
n_iters = 200 # the number of iterations
eps = 1e-5 # check for convergence
# multiplicative-update NMF: each factor is scaled by an elementwise ratio,
# which keeps W and H non-negative given non-negative starting values
# loop to find chi2 error against iterations (about 12.5 mins on microsoft)
for i in range(n_iters):
    # update first on H
    H = H * ((W.T.dot(train_noisy_nmf)) / (W.T.dot(W.dot(H)))) ## <-- EDIT THIS LINE
    # the update on W
    W = W * ((train_noisy_nmf.dot(H.T)) / (W.dot(H.dot(H.T)))) ## <-- EDIT THIS LINE
    # compute the chi2 and append to list
    chi2.append(cost(train_noisy_nmf, W, H))
# check for convergence
# scan the recorded costs afterwards and report the first iteration whose
# improvement over the previous iteration drops below eps
for i in range(1, len(chi2)):
    if abs(chi2[i-1] - chi2[i]) < eps:
        print(f"Converged at iteration {i} and the difference is {chi2[i-1] - chi2[i]}")
        break
Converged at iteration 82 and the difference is 9.69217069603321e-06
print("The loss for NMF at m=100: ", chi2[99])
The loss for NMF at m=100: 0.007368975662134399
# plot the cost as a function of the number of iterations
# (uses the chi2 list recorded during the m=10 NMF fit above)
plt.plot(chi2, label="Cost")
plt.xlabel("Number of Iterations")
plt.ylabel("chi2 Cost")
plt.title("Cost as a function of the number of iterations")
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()
Particularly, epsilon is used to print the first iteration at which the cost differs from the previous iteration's by less than eps. If that iteration is far below the chosen n_iter and the cost is monotonically decreasing, we have reason to believe the chosen number of iterations is sufficient for convergence. As the output and the plot show, the first iteration that meets the criterion is 82 and the cost curve is decreasing. Therefore, n_iter=200 is a suitable choice to make sure that the cost converges.
# show the m=10 NMF basis images; row i of H plays the role that a
# principal component plays in PCA
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for idx, panel in enumerate(axes.ravel()):
    panel.imshow(H[idx].reshape(28, 28))
    panel.set_ylabel(f"NMF{idx+1}")
plt.tight_layout();
# show the leading 10 PCA components for comparison with the NMF bases
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for idx, panel in enumerate(axes.ravel()):
    panel.imshow(eigenvectors[:, idx].reshape(28, 28))
    panel.set_ylabel(f"PC{idx+1}")
plt.tight_layout();
The components of PCA and NMF are so different. This arises from their different natures of the way of decomposition.
The components of PCA are called 'eigenfaces', as each principal component of PCA is a modified version of the images. When combined together linearly, a lot of cancellations are involved.
The components of NMF have a direct visual meaning: it contains local and sparse features of the images.
In all, each imshow shown for PCA components is a global representation of an image, but for NMF, each imshow is a partial and sparse representation of the image.
(references: https://www.nature.com/articles/44565)
# define a function to compute the mse score
def mse_score(X_either, X_reconstructed):
    """Return the mean squared error between two image arrays."""
    diff = X_either - X_reconstructed
    return (diff ** 2).mean()
# train on PCA
m = 100
# fit PCA with m components on the noisy training data; returns the
# projections, eigenvectors and eigenvalues
X_pca, train_eigenvectors, train_eigenvalues = pca_function(train_noisy_pca, m)
# the reconstructed images by pca: project the test data onto the m leading
# eigenvectors and map back to pixel space
X_reconstructed_pca = test_noisy_pca @ train_eigenvectors @ train_eigenvectors.T # formula from notes
# train on NMF
# construct placeholder matrices
np.random.seed(0)
W = np.random.rand(train_noisy_nmf.shape[0], m)
H = np.random.rand(m, train_noisy_nmf.shape[1])
chi2 = []
n_iters = 200 # the number of iterations
# same multiplicative updates as the earlier NMF fit, now with m=100
# loop over (about 20 mins on microsoft)
for i in range(n_iters):
    # update first on H
    H = H * ((W.T.dot(train_noisy_nmf)) / (W.T.dot(W.dot(H))))
    W = W * ((train_noisy_nmf.dot(H.T)) / (W.dot(H.dot(H.T))))
    chi2.append(cost(train_noisy_nmf, W, H))
# reconstructed images by NMF (formula derived from notes: rows of H are analogue to principle components in PCA)
X_reconstructed_nmf = test_noisy_nmf @ H.T @ H
Revert the images back to their original scale
# undo the earlier scaling: PCA used standardization (mu/std), NMF used min-max
X_reconstructed_pca_revert = normalize(X_reconstructed_pca,mu_pca,std_pca,revert=True)
X_reconstructed_nmf_revert = normalize_nmf(X_reconstructed_nmf,min_nmf,max_nmf,revert=True)
np.random.seed(0)
# randomly choose an image from the noisy test data set
random_image = np.random.randint(0, test_noisy_pca.shape[0])
# plot the noisy, the reconstructed and the original image in a 2x2 grid
plt.figure(figsize=(10, 6))
plt.subplot(2, 2, 1)
plt.imshow(test_noisy[random_image].reshape(28, 28))
plt.title("Noisy Image")
# plot the reconstructed image by PCA
plt.subplot(2, 2, 2)
plt.imshow(X_reconstructed_pca_revert[random_image].reshape(28, 28))
plt.title("Reconstructed Image by PCA")
# plot the reconstructed image by NMF
plt.subplot(2, 2, 3)
plt.imshow(X_reconstructed_nmf_revert[random_image].reshape(28, 28))
plt.title("Reconstructed Image by NMF")
# plot the original image
plt.subplot(2, 2, 4)
plt.imshow(test_original[random_image].reshape(28, 28))
plt.title("Original Image")
plt.tight_layout()
plt.show()
In terms of denoising, NMF performs better, as the feature distribution is more like the original image. But in terms of visualizing images, PCA performs better, as it's much clearer than the reconstructed image by NMF.
Explanation: NMF has the following properties
The reconstructed images are built using a linear combination of different local parts of the image(the non-negativeness of H only gives addition). This means that not all available local features are used in the linear combination. Therefore, due to the properties mentioned above, sparse addition of features is giving much less noise but also is less likely to give as many features.
On the contrary, for PCA:
The reconstructed images are built on a linear combination, not only of different eigen-images, but one in which both additions and subtractions are allowed. The use of the whole image (rather than sparse and local features) in reconstruction, and the complexity of the additions and subtractions involved, mean that the reconstructed image will combine more features as well as noise, thus giving a better-visualized but noisier image.
(ref again from 1.2.2: https://www.nature.com/articles/44565)
# define a function to compute the mse score
# (re-defined here, identical to the earlier cell)
def mse_score(X_either, X_reconstructed):
    """Return the mean squared error between two image arrays."""
    return np.square(X_either - X_reconstructed).mean()
np.random.seed(10)
m_range = np.arange(5, 601, 5) # pc values
# randomly choose one test image whose reconstructions will be kept as examples
random_image = np.random.randint(0, test_noisy_pca.shape[0])
with_uncorrupted_mse_score_lis = []  # MSE vs the clean test images
with_corrupted_mse_score_lis = []    # MSE vs the noisy test images
pca_examples_holder = []             # example reconstructions at selected m
for m in m_range:
    # fit PCA with m components on the noisy training data
    _, train_eigenvectors, _ = pca_function(train_noisy_pca, m)
    X_reconstructed_pca = test_noisy_pca @ train_eigenvectors @ train_eigenvectors.T
    X_reconstructed_pca_revert = normalize(X_reconstructed_pca, mu_pca, std_pca, revert=True)
    with_uncorrupted_mse_score_lis.append(mse_score(test_original/255, X_reconstructed_pca_revert))
    with_corrupted_mse_score_lis.append(mse_score(test_noisy/255, X_reconstructed_pca_revert))
    if m in [10, 40, 100, 200, 400, 600]:
        # reuse the reverted reconstruction computed above (the original
        # recomputed the same normalize(..., revert=True) call redundantly)
        pca_examples_holder.append(X_reconstructed_pca_revert[random_image])
# plot the MSE
# only the first 80 entries of m_range (m = 5..400) are shown here; the full
# range is plotted further below
plt.figure(figsize=(10, 6))
plt.plot(m_range[0:80] , with_uncorrupted_mse_score_lis[0:80], label="MSE of reconstructed with denoised images")
plt.plot(m_range[0:80], with_corrupted_mse_score_lis[0:80], label="MSE of reconstructed with noisy images")
# reference line at m=100, discussed in the accompanying text
plt.axvline(100, color="red", linestyle='--')
plt.xlabel('number of principle components')
plt.ylabel('MSE')
plt.title('MSE of images against number of principle components')
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()
# plot the example figure: original, noisy, then PCA reconstructions at
# increasing numbers of principal components (the original repeated the
# same three subplot lines eight times; a loop removes the duplication)
plt.figure(figsize=(10, 10))
panels = [(test_original[random_image], 'original test image'),
          (test_noisy[random_image], 'noisy test image')]
panels += [(pca_examples_holder[j], f'm = {m_val}')
           for j, m_val in enumerate([10, 40, 100, 200, 400, 600])]
for pos, (image, label) in enumerate(panels, start=1):
    plt.subplot(2, 4, pos)
    plt.imshow(image.reshape(28, 28))
    plt.ylabel(label)
plt.suptitle("Reconstructed images for test noisy data at different m")
plt.tight_layout()
plt.show()
The MSE decreases fast over roughly the first m=100 components (dotted red line). This is because PCA factorizes out principal components in descending order, and eigenvectors (which statistically represent variance explained) corresponding to greater eigenvalues carry more of the represented features in the reconstructed images. Therefore the MSE decreases rapidly at the start.
After m=100, the MSE with test data still decreases while MSE with original data remains almost stable. This means that the added principle components help little in denoising the noisy images and thus adding more components is pointless in decreasing MSE. But for test noisy data, the added components are adding more information (but it's just noise) about images to the basis of components and give a decreasing MSE.
The trend above(orange) can be corresponded in the reconstructed test noisy images plotted for different $m$s:
Images become clear rapidly from m=10 to m=40 and m=100.
Images change little, either the background or the number part, from m=100 to m=400, because adding components are just adding more noise.
In all, none of the reconstructed images resemble the original image for the background part, which means that the PCA does not perform well in denoising data. But the reconstructed images are closer to the noisy data, meaning that PCA components fit well to unseen noisy data and is good at visualizing noisy images.
As m further increases, the MSE with reconstructed images and test noisy images are expected to further decrease (which is verified below). This trend can also be seen from the last plot.
# verify the MSE after 400
# same noisy-image MSE curve as above, but over the full m range (5..600)
plt.figure(figsize=(10, 6))
plt.plot(m_range, with_corrupted_mse_score_lis, color='orange', label="MSE of reconstructed with test noisy images")
plt.xlabel('number of principle components')
plt.ylabel('MSE')
plt.title('MSE of images against number of principle components')
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()
In this section, digits' images are reconstructed based on 5 principle components and then are being clustered by GMM models, with 10, 5 and 8 hidden components respectively.
For each of 3 models:
Intuitively, number of hidden components should be at least the number of digits, but for similar types of images, especially those only built on 5 pc, GMM with less hidden components are likely to perform better. The discussion will be included in the section.
Note that in EM algorithm, a convergence check is set up after each EM step to stop the next iteration if the previous mu and sigma are close to avoid overflow.
class GMModel:
    """Gaussian Mixture Model container.

    The original class body held four separate string literals after the
    docstring; those are discarded expressions, not documentation, so they
    are merged into this single docstring.

    Attributes:
        dim: number of mixture components.
        phi: mixture weights (the prior P(C_i = j)), initialised uniform.
        weights: per-sample likelihood placeholder of shape X.shape,
            initialised uniform.
        mu: list of component means, one per cluster.
        sigma: list of component covariance matrices, one per cluster.
    """
    def __init__(self, X, dim):
        """Initialise parameters from the data.

        Args:
            X: (n_samples, n_features) data array.
            dim: number of mixture components.
        """
        self.dim = dim # number of k
        # initial weights/ P(Ci=j)/ prior
        self.phi = np.full(shape=self.dim, fill_value=1/self.dim)
        # initial weights/ P(Xi|Ci=j)/ likelihood
        self.weights = np.full(shape=X.shape, fill_value=1/self.dim)
        n, m = X.shape
        # random rows used as the initial means (indices may repeat)
        random_row = np.random.randint(low=0, high=n, size=self.dim)
        # every component starts from a random data point as its mean and
        # the covariance of the full data set as its covariance
        self.mu = [ X[row_index,:] for row_index in random_row ]
        self.sigma = [ np.cov(X.T) for _ in range(self.dim) ]
def cluster_probabilities(gmm, X):
    """Predicts cluster probability for each data point."""
    n_samples = X.shape[0]
    # likelihood[i, j] = p(x_i | theta_j)
    likelihood = np.zeros((n_samples, gmm.dim))
    for j in range(gmm.dim):
        component = multivariate_normal(mean=gmm.mu[j], cov=gmm.sigma[j])
        likelihood[:, j] = component.pdf(X)
    # Bayes' rule: posterior responsibility of each component per sample,
    # normalised across components
    weighted = likelihood * gmm.phi
    return weighted / weighted.sum(axis=1, keepdims=True)
def predict(gmm, X):
    """Performs hard clustering"""
    # assign every sample to its most probable mixture component
    return cluster_probabilities(gmm, X).argmax(axis=1)
# implement EM algorithm
def fitStep(gmm, X):
    """Performs an EM step by updating all parameters"""
    # E-step: responsibilities under the current mu/sigma/phi
    resp = cluster_probabilities(gmm, X)
    # M-step: re-estimate phi, mu and sigma from the responsibilities
    gmm.phi = resp.mean(axis=0)  # prior
    for j in range(gmm.dim):
        resp_j = resp[:, [j]]
        mass = resp_j.sum()
        gmm.mu[j] = (X * resp_j).sum(axis=0) / mass
        # bias=True: normalize by the number of observations
        gmm.sigma[j] = np.cov(X.T, aweights=(resp_j / mass).flatten(), bias=True)
# train the model with EM. But with a convergence check
def train_gmm(X, n_components, n_iters=1000, eps=1e-8):
    """Fit a GMModel with EM, stopping early once parameters stabilise.

    Args:
        X: (n_samples, n_features) data array.
        n_components: number of mixture components.
        n_iters: maximum number of EM iterations.
        eps: absolute tolerance for the convergence check. The original
            accepted eps but never used it; it is now forwarded to
            np.allclose as atol, whose default (1e-8) matches this
            function's default, so default behaviour is unchanged.
    Returns:
        The fitted GMModel.
    """
    gmm = GMModel(X, n_components)
    prev_mu = gmm.mu.copy()
    prev_sigma = gmm.sigma.copy()
    for i in range(n_iters):
        fitStep(gmm, X)
        # converged when neither the means nor the covariances moved
        if np.allclose(prev_mu, gmm.mu, atol=eps) and np.allclose(prev_sigma, gmm.sigma, atol=eps):
            print(f"Converged at iteration {i}!")
            break
        prev_mu = gmm.mu.copy()
        prev_sigma = gmm.sigma.copy()
    return gmm
m = 5
np.random.seed(2)
# PCA with 5 components on the first 1000 original training images
X_pca_131, eigenvectors_131, eigenvalues_131 = pca_function(train_original_pca[0:1000], m)
# NOTE(review): this GMModel is immediately replaced by the one built inside
# train_gmm, but constructing it advances the NumPy RNG state, so removing
# the line would change the reproducible results reported below
gmm = GMModel(X_pca_131, 10)
gmm = train_gmm(X_pca_131, n_components=10, n_iters=1000, eps=1e-8)
# hard clustering
cluster_labels = predict(gmm, X_pca_131)
# visualize the space spanned by the first two principal components
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Cluster Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using cluster labels')
plt.show()
Converged at iteration 378!
# true digit labels of the first 1000 training samples (column 0 of the array)
class_labels = MNIST_train[:1000, 0]
# visualize the space spanned, this time colored by the true digit class
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=class_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Class Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using class labels.')
plt.show()
def log_probs(gmm, X):
    """
    Return a matrix of shape (n_sample, n_components), where each element
    is the log probability of X_i under the kth mixture component.
    """
    out = np.zeros((X.shape[0], gmm.dim))
    for k in range(gmm.dim):
        component = multivariate_normal(mean=gmm.mu[k], cov=gmm.sigma[k])
        # log density of every sample under component k
        out[:, k] = component.logpdf(X)
    return out
def label_cluster_mapping(gmm, X, class_labels):
    """
    Return a map of class labels to best_fitting cluster index.
    Args:
        gmm: Gaussian mixture model.
        X: data set
        class_labels: given class labels.
    """
    per_sample = log_probs(gmm, X)
    mapping = {}
    for lbl in np.unique(class_labels):
        # total log-likelihood of this class's samples under each component
        totals = per_sample[class_labels == lbl].sum(axis=0)
        # best-fitting component = the one with the highest total
        mapping[lbl] = np.argmax(totals)
    return mapping
# print the map
# maps each digit class to the mixture component maximising the summed
# log-likelihood of that class's samples
class_cluster_map = label_cluster_mapping(gmm, X_pca_131, class_labels)
print("The label cluster mapping is: ", class_cluster_map)
The label cluster mapping is: {0: 5, 1: 0, 2: 4, 3: 1, 4: 3, 5: 1, 6: 6, 7: 3, 8: 1, 9: 3}
The keys of the mapping dictionary are the class label and the values of the mapping are best-fitting index of normal distributions:
$\cdot$ One interesting thing to notice is that digits 3, 5 and 8 are all mapped to label 1. In real handwriting, these digits do look alike in their rough shapes.
$\cdot$ In addition, digit 4, 7 and 9 are clustered to label 3.
m=5 are clear in the background but blurred on the lighter part -- where the digits manifest. This tells the features captured are insufficient to represent distinct features of images of digits. And digits with similar rough shapes are more likely to be clustered as the identical digit.Analysis wrt the plot above is included in the explanation in 1.3.3
# log probs matrix: containing all the log-probs of data point, X_i to the kth mixture components
# (rows: samples, columns: mixture components)
log_cluster_prob = log_probs(gmm, X_pca_131)
print(log_cluster_prob)
[[ -16.97570347 -18.7475566 -14.76573705 ... -19.31242479 -282.41979186 -11.23842018] [-114.66542145 -15.01411727 -19.51109961 ... -32.36062497 -200.37780552 -44.10846414] [-101.16633094 -15.31546101 -11.40404744 ... -33.75332819 -170.75535242 -57.92117636] ... [-148.20503636 -13.40809482 -17.53956278 ... -26.11280143 -132.32450628 -47.96001604] [-698.52436219 -12.81893972 -52.78271987 ... -30.4427656 -110.48773354 -63.4990667 ] [-226.9374572 -19.090202 -15.9754111 ... -12.50942501 -110.07120331 -34.46886479]]
# replot from 1.3.2
# same cluster-colored scatter as before, repeated for side-by-side comparison
# visualize the space spanned
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Cluster Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using cluster labels')
plt.show()
# per-class scatter plots, colored by the probability of the best-fitting cluster
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(30, 10), sharex=True, sharey=True)
# one shared colormap for all panels; loop-invariant, so hoisted out of the loop
cmap = plt.cm.get_cmap('viridis')
for i, ax in enumerate(axes.flatten()):
    label_indicator = (class_labels == i)
    cluster_probs_total = cluster_probabilities(gmm, X_pca_131[label_indicator, :])
    # probability of each sample under its class's best-fitting component
    cluster_probs = cluster_probs_total[:, class_cluster_map[i]]
    norm = plt.Normalize(vmin=cluster_probs.min(), vmax=cluster_probs.max())
    sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
    # renamed from `colors`, which shadowed the matplotlib.colors module
    # imported at the top of the file
    point_colors = sm.to_rgba(cluster_probs)
    ax.scatter(X_pca_131[label_indicator, 0], X_pca_131[label_indicator, 1], c=point_colors)
    ax.set_title('Digit {}'.format(i))
fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.75, label='Cluster Probabilities')
plt.suptitle("Visualization of Each Class, Colored By the Cluster Probability of the Best-fitting Cluster")
plt.show()
The plot above gives the uncertainty of the class points assigned to the best-fitting cluster. The brighter the color, the more likely that the digits are assigned to the best-fitting cluster, based on 5 principle components.
To view the uncertainty:
By this, GMM model with 10 hidden components has a great uncertainty: Apart from digit 0, 1, 2 and 3, other digit classes either has a distribution with a wide range of values of cluster probabilities or there are many dark-color points which makes it hard to determine the dominant label.
From this, the clusters are heavily stacked on the lower-right part of the plot, giving low cluster probability to the corresponding areas of each class; e.g. the green dots in the scatter plot have orange, red and pink points stacked on them, giving a dark cluster-probability color at the corresponding lower-right part of the plot for digit 8. Cluster 5, the brown dots, is individually and densely distributed around the region $(-25, 15)\times(5, 15)$. Therefore, it gives yellow dots to the corresponding region in the plot of digit 0.
def train_map_plots(hidden_components, n_its = 1000):
    """Train a GMM on X_pca_131, print the class-to-cluster map, and plot.

    Args:
        hidden_components: number of mixture components.
        n_its: maximum number of EM iterations, forwarded to train_gmm.
            (Bug fix: the original accepted this argument but never passed
            it on, so train_gmm always ran with its own default of 1000;
            the default here matches, so default behaviour is unchanged.)
    """
    # train
    gmm = train_gmm(X_pca_131, hidden_components, n_iters=n_its)
    # map each digit class to its best-fitting cluster
    class_labels = MNIST_train[:1000, 0]
    label_cluster_map = label_cluster_mapping(gmm, X_pca_131, class_labels)
    print("Label-Cluster Index Map: ", label_cluster_map)
    # plot: cluster labels on the left, true class labels on the right
    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(30, 10))
    # color by clusters
    cluster_labels = predict(gmm, X_pca_131)
    sc0 = ax0.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
    ax0.set_xlabel('PC1')
    ax0.set_ylabel('PC2')
    ax0.set_title(f'Coloring Cluster Labels, for {hidden_components} components')
    cbar = fig.colorbar(sc0, ax=ax0)
    cbar.ax.set_ylabel('Cluster Labels')
    # color by class labels
    sc1 = ax1.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=class_labels, cmap="tab10")
    ax1.set_xlabel('PC1')
    ax1.set_ylabel('PC2')
    ax1.set_title(f'Coloring Class Labels, for {hidden_components} components')
    cbar = fig.colorbar(sc1, ax=ax1)
    cbar.ax.set_ylabel('Class Labels')
    plt.show()
    # plot each individual class, colored by the probability of its
    # best-fitting cluster
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(30, 10), sharex=True, sharey=True)
    cmap = plt.cm.get_cmap('viridis')  # loop-invariant, hoisted out of the loop
    for i, ax in enumerate(axes.flatten()):
        mask = (class_labels == i)
        cluster_probs_total = cluster_probabilities(gmm, X_pca_131[mask, :])
        cluster_probs = cluster_probs_total[:, label_cluster_map[i]]
        norm = plt.Normalize(vmin=cluster_probs.min(), vmax=cluster_probs.max())
        sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        # renamed from `colors`, which shadowed the matplotlib.colors module
        point_colors = sm.to_rgba(cluster_probs)
        ax.scatter(X_pca_131[mask, 0], X_pca_131[mask, 1], c=point_colors)
        ax.set_title('Digit {}'.format(i))
    plt.suptitle("Visualization of Each Class of Digit, Colored By the Cluster Probability of the Best-fitting Cluster Label")
    fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.75, label='Cluster Probabilities')
    plt.show()
Retrain with 5 hidden components
# retrain and visualize with 5 mixture components
hidden_components = 5
train_map_plots(hidden_components, n_its = 1000)
Converged at iteration 179!
Label-Cluster Index Map: {0: 2, 1: 4, 2: 3, 3: 3, 4: 1, 5: 2, 6: 3, 7: 1, 8: 2, 9: 3}
Retrain with 8 components
# retrain and visualize with 8 mixture components
hidden_components = 8
train_map_plots(hidden_components, n_its = 1000)
Converged at iteration 305!
Label-Cluster Index Map: {0: 5, 1: 6, 2: 3, 3: 2, 4: 3, 5: 5, 6: 3, 7: 1, 8: 7, 9: 3}
Usually, the number of hidden components should be at least the number of classes (in this case 10), assuming each digit follows a distinct distribution. But if elements of the classes are 'similar', fewer hidden components can lead to a good clustering outcome. From 1.2.4, we see that the reconstructed images, even at m=10, are very vague — and only m=5 is used in this section. In this case, the reconstructed digit images are really blurred (also shown below) and are likely to present highly similar patterns and information to the clustering machine.
| 5 | 8 | 10 | |
|---|---|---|---|
| digits clustered to the same distribution: | 479 to 3, 358 to 1 | 2369 to 3, 05 to 2, 47 to 1 | 2469 to 3, 0 to 55 |
| numerous yellow points (Y/N) | YNN, NNN | YNNN, YN, NY | NNNN, Y |
In terms of distributions of clusters in the scatter plot, it can be seen obviously that for reduced hidden components to 5, the clusters are more well-separated, showing a reduced uncertainty.
One thing in common is that cluster probabilities are high for clusters that densely overlap. This is because points in highly dense areas mean the samples are similar in these 2 dimensions. This makes clustering more difficult and leads to more uncertainty.
Complement to the analysis
# to show: reconstructed images for 5 principle components is highly blurred
# reconstruct a test image from the 5-component PCA basis fitted in 1.3.1
test_constructed = test_original_pca @ eigenvectors_131 @ eigenvectors_131.T
plt.imshow(test_constructed[0].reshape(28,28))
plt.show()
# import and do the data exploration
gene_data = pd.read_csv("gene_expression_data.csv", decimal=",") # as a pandas data frame
# expressions: all columns except the last
gene_expression = gene_data[gene_data.columns[:-1]].astype(float)
# type (labels): the last column
gene_type = gene_data[gene_data.columns[-1]]
display(gene_data.head(5))
# (typo "espressions" fixed to "expressions" in the printed messages)
print("The type of the gene_data, expressions and types:", type(gene_data))
print("The type of expressions:", type(gene_expression))
print("The type of types:", type(gene_type))
print("The shape of the data is:", gene_data.shape)
| Gene 0 | Gene 1 | Gene 2 | Gene 3 | Gene 4 | Gene 5 | Gene 6 | Gene 7 | Gene 8 | Gene 9 | ... | Gene 86 | Gene 87 | Gene 88 | Gene 89 | Gene 90 | Gene 91 | Gene 92 | Gene 93 | Gene 94 | Type | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9.79608829288 | 0.591870870063 | 0.591870870063 | 0.0 | 11.4205708246 | 13.4537593388 | 4.41184651859 | 5.4123344238 | 10.7716132653 | 10.2256653627 | ... | 5.97436869818 | 8.08651310808 | 12.7277503154 | 15.2057168981 | 6.43811649662 | 6.412576621 | 0.0 | 6.81472985199 | 13.6181445741 | PRAD |
| 1 | 10.0704698332 | 0.0 | 0.0 | 0.0 | 13.085671621 | 14.5318626848 | 10.4622977655 | 9.83292639746 | 13.5203117438 | 13.9680457391 | ... | 0.0 | 0.0 | 11.1972044043 | 12.9939325894 | 10.8007462415 | 10.7498107801 | 0.0 | 11.445609809 | 0.0 | LUAD |
| 2 | 8.97091978401 | 0.0 | 0.452595434703 | 0.0 | 8.26311893627 | 9.75490753946 | 8.96454880885 | 9.94811313134 | 8.6937726804 | 8.7761105697 | ... | 3.90715987355 | 5.32410132 | 11.4870662364 | 13.3805963452 | 6.65623606704 | 10.2097335917 | 0.0 | 7.7488301787 | 12.759975541 | PRAD |
| 3 | 8.52461614952 | 1.03941918313 | 0.434881719407 | 0.0 | 10.7985204031 | 12.2630197299 | 7.44069478818 | 8.06234301354 | 8.80208333007 | 9.23748723755 | ... | 4.29608291884 | 6.95974697548 | 12.974638649 | 14.8918121772 | 6.03072451189 | 7.31564773826 | 0.434881719407 | 7.11792356209 | 12.3532764196 | PRAD |
| 4 | 8.04723845046 | 0.0 | 0.0 | 0.360982241369 | 12.2830101953 | 14.033758513 | 8.71918001723 | 8.83147193285 | 8.46207277429 | 8.21120206054 | ... | 0.0 | 0.0 | 11.3372372064 | 13.3900614488 | 5.98959318494 | 8.35967050637 | 0.0 | 6.32754545866 | 0.0 | BRCA |
5 rows × 96 columns
The type of the gene_data, espressions and types: <class 'pandas.core.frame.DataFrame'> The type of espressions: <class 'pandas.core.frame.DataFrame'> The type of types: <class 'pandas.core.series.Series'> The shape of the data is: (800, 96)
# standardize the data
def standardize(X):
    """Return a column-wise standardized copy of X.

    Args:
        X: numpy array or pandas DataFrame of shape (n_samples, n_features).
    Returns:
        numpy array with zero mean and unit standard deviation per column.
        NOTE(review): a zero-variance column would divide by zero — assumes
        the data has no constant columns; verify for new data sets.
    """
    # idiom fix: isinstance instead of type(...) != comparison
    if not isinstance(X, np.ndarray):
        X = X.to_numpy()
    mean = np.mean(X, 0)
    std = np.std(X, 0)
    return (X - mean)/std
In this section, k-means clustering is implemented for the gene expressions, where k is treated as a hyperparameter to tune. The highest Calinski-Harabasz index is used as a measure to find the optimal k, and consistency of the clustering is assessed by the homogeneity score.
Two things to notice:
In the initialization of labels in k-means algorithm, in order to avoid the situation where none of a sample is assigned to a certain cluster, we first randomly choose n-k samples from range 0 to k, and then add the labels with additional k labels, each of which is from a cluster. This guarantees at least one intially assigned sample for each cluster.
In computing $a_{ck}$ in homogeneity score, we should avoid it to be 0 as logarithm will be taken on it. For this case, simply skip it ie. add zero to the total sum, which will result into a slight underestimation in the true value of the total sum, compared to the result given by sklearn.
gene_expression = standardize(gene_expression)
def compute_centroids(X, k):
    """
    Randomly assign initial labels and return the centroid of each cluster.
    Args:
        X: data
        k: number of clusters
    Returns:
        (centroids, labels): (k, n_features) centroid array and the random
        per-sample labels used to compute it.
    """
    # idiom fix: isinstance instead of type(...) != comparison
    if not isinstance(X, np.ndarray):
        X = X.to_numpy()
    n_samples, n_features = X.shape
    # assign labels with a 'trick' to ensure at least 1 point per cluster:
    # draw n-k labels freely, then append one label from each cluster
    labels = np.random.randint(low=0, high=k, size=n_samples-k)
    labels = np.concatenate((labels, np.arange(k)))
    # NOTE(review): random.shuffle uses the stdlib RNG, which np.random.seed
    # does not seed — runs are only reproducible if random.seed is also set
    random.shuffle(labels)
    # initialization
    centroids = np.zeros((k, n_features))
    # centroid of each cluster = mean of the points carrying its label
    for i in range(k):
        centroids[i] = np.mean(X[labels==i], axis=0)
    return centroids, labels
# k-means algorithm, from coding book
def k_clustering(X, k, max_iter, message=False):
    """
    Return the updated centroids and labels for fixed k.
    Args:
        X: data set
        k: number of clusters
        max_iter: maximum number of iterations
        message: if True, print per-iteration progress
    Returns:
        (centroids, labels) after convergence or max_iter iterations.
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()
    difference = 0
    new_labels = np.zeros(len(X))
    # initialize centroids: each time different outcomes
    centroids, labels = compute_centroids(X, k)
    for i in range(max_iter):
        if message==True:
            print('Iteration:', i)
        # distances: between data points and centroids (one row per centroid)
        distances = np.array([np.linalg.norm(X - c, axis=1) for c in centroids])
        # new_labels: computed by finding centroid with minimal distance
        new_labels = np.argmin(distances, axis=0)
        if (labels==new_labels).all():
            # labels unchanged -> converged, stop iterating
            labels = new_labels
            if message==True:
                print('Labels unchanged! Terminating k-means.')
            break
        else:
            # labels changed
            # difference: percentage of changed labels
            difference = np.mean(labels!=new_labels)
            if message==True:
                print('%4f%% labels changed' % (difference * 100))
            labels = new_labels
            for c in range(k):
                # update centroids by taking the mean over associated data points
                # (empty clusters are skipped so the old centroid is kept)
                if (labels == c).any():
                    centroids[c] = np.mean(X[labels==c], axis=0)
    return (centroids, labels)
# Calinski-Harabasz index
def bcsm(X, k, centroids, labels):
    """
    Return the between-cluster sum of squares (BCSM).
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updated labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()
    # number of points in each cluster
    sizes = np.array([sum(labels==c) for c in range(k)])
    # grand centroid of the whole data set
    grand_mean = np.mean(X, axis=0)
    # squared distance from each cluster centre to the grand centroid
    sq_dists = np.array([np.linalg.norm(centroids[c] - grand_mean)**2 for c in range(k)])
    # size-weighted sum of the squared distances
    return sum(sizes * sq_dists)
def wcsm(X, k, centroids, labels):
    """
    Return the within-cluster sum of squares (WCSM).
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updated labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()
    wcsm_value = 0
    for i in range(k):
        # the squared Frobenius norm of the residual block equals the sum of
        # squared point-to-centroid distances for this cluster; the original
        # wrapped this scalar in sum(np.array([...])), which had no effect
        wcsm_value += np.linalg.norm(X[labels==i] - centroids[i])**2
    return wcsm_value
def CH_k(X, k, centroids, labels):
    """
    Return the Calinski-Harabasz index for the given clustering.
    The greater, the better the classification.
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updated labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()
    between = bcsm(X, k, centroids, labels)
    within = wcsm(X, k, centroids, labels)
    n_samples = X.shape[0]
    # CH = (BCSM / (k-1)) / (WCSM / (n-k)), rearranged into one expression
    return between * (n_samples - k) / ((k - 1) * within)
# define elbow function
def norm_within_cluster_dis(X, cluster_labels):
    """
    Return the normalised within-cluster distance cost W_c (elbow criterion).

    Uses the standard identity
        (1 / (2|C|)) * sum_{j,l in C} ||x_j - x_l||^2
            = sum_{j in C} ||x_j - mean(C)||^2,
    replacing the original O(|C|^2) pairwise double loop with one
    O(|C|) pass per cluster; the value is mathematically identical.

    Args:
        X: data set (numpy array, one data point per row)
        cluster_labels: updated cluster label for each row of X
    """
    w_c = 0.0
    for lbl in np.unique(cluster_labels):
        # extract the rows belonging to this cluster
        members = X[cluster_labels == lbl, :]
        centroid = members.mean(axis=0)
        # sum of squared distances of members to their cluster mean
        w_c += np.sum((members - centroid) ** 2)
    return w_c
# run for different k and 5 initializations for each k
np.random.seed(44)
k_range = 16  # k sweeps 2..k_range-1 (k starts at 2 because 'k-1' divides in the CH index)
k_range_ch_index = np.array([])
for k_value in range(2, k_range):
    ch_k = np.array([])
    # 5 different random initializations for each k; average their CH indices
    # (removed unused accumulator 'holding_l' from the original)
    for _ in range(5):
        up_centroids, up_labels = k_clustering(gene_expression, k_value, max_iter=70, message=True)
        # checked all converged with max_iter
        ch_index = CH_k(gene_expression, k_value, up_centroids, up_labels)
        ch_k = np.append(ch_k, ch_index)
    k_range_ch_index = np.append(k_range_ch_index, np.mean(ch_k))
# print the optimal k corresponding ch index
print(f"The optimal k and CH index are: {2 + np.argmax(k_range_ch_index)} and {np.max(k_range_ch_index)}")
Iteration: 0 46.250000% labels changed Iteration: 1 1.500000% labels changed Iteration: 2 0.500000% labels changed Iteration: 3 Labels unchanged! Terminating k-means. Iteration: 0 46.375000% labels changed Iteration: 1 6.750000% labels changed Iteration: 2 3.625000% labels changed Iteration: 3 1.375000% labels changed Iteration: 4 0.500000% labels changed Iteration: 5 Labels unchanged! Terminating k-means. Iteration: 0 46.750000% labels changed Iteration: 1 19.625000% labels changed Iteration: 2 3.750000% labels changed Iteration: 3 1.250000% labels changed Iteration: 4 0.375000% labels changed Iteration: 5 Labels unchanged! Terminating k-means. Iteration: 0 45.125000% labels changed Iteration: 1 7.250000% labels changed Iteration: 2 7.000000% labels changed Iteration: 3 3.000000% labels changed Iteration: 4 0.750000% labels changed Iteration: 5 0.375000% labels changed Iteration: 6 Labels unchanged! Terminating k-means. Iteration: 0 46.500000% labels changed Iteration: 1 9.000000% labels changed Iteration: 2 2.375000% labels changed Iteration: 3 1.375000% labels changed Iteration: 4 1.000000% labels changed Iteration: 5 1.125000% labels changed Iteration: 6 1.125000% labels changed Iteration: 7 0.625000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 Labels unchanged! Terminating k-means. Iteration: 0 62.500000% labels changed Iteration: 1 10.625000% labels changed Iteration: 2 4.750000% labels changed Iteration: 3 4.500000% labels changed Iteration: 4 4.125000% labels changed Iteration: 5 5.625000% labels changed Iteration: 6 2.625000% labels changed Iteration: 7 0.500000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 0.500000% labels changed Iteration: 10 0.125000% labels changed Iteration: 11 Labels unchanged! Terminating k-means. 
Iteration: 0 62.750000% labels changed Iteration: 1 24.875000% labels changed Iteration: 2 13.000000% labels changed Iteration: 3 5.375000% labels changed Iteration: 4 1.875000% labels changed Iteration: 5 1.500000% labels changed Iteration: 6 1.250000% labels changed Iteration: 7 0.500000% labels changed Iteration: 8 0.625000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.375000% labels changed Iteration: 11 0.125000% labels changed Iteration: 12 0.375000% labels changed Iteration: 13 0.125000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. Iteration: 0 65.500000% labels changed Iteration: 1 12.125000% labels changed Iteration: 2 2.750000% labels changed Iteration: 3 0.625000% labels changed Iteration: 4 0.375000% labels changed Iteration: 5 0.250000% labels changed Iteration: 6 0.250000% labels changed Iteration: 7 0.375000% labels changed Iteration: 8 0.125000% labels changed Iteration: 9 0.250000% labels changed Iteration: 10 0.625000% labels changed Iteration: 11 1.250000% labels changed Iteration: 12 2.250000% labels changed Iteration: 13 4.500000% labels changed Iteration: 14 7.125000% labels changed Iteration: 15 5.000000% labels changed Iteration: 16 2.250000% labels changed Iteration: 17 1.625000% labels changed Iteration: 18 1.250000% labels changed Iteration: 19 0.500000% labels changed Iteration: 20 0.625000% labels changed Iteration: 21 0.750000% labels changed Iteration: 22 0.375000% labels changed Iteration: 23 0.375000% labels changed Iteration: 24 0.125000% labels changed Iteration: 25 0.125000% labels changed Iteration: 26 Labels unchanged! Terminating k-means. 
Iteration: 0 63.125000% labels changed Iteration: 1 25.750000% labels changed Iteration: 2 7.375000% labels changed Iteration: 3 2.000000% labels changed Iteration: 4 1.500000% labels changed Iteration: 5 1.875000% labels changed Iteration: 6 1.625000% labels changed Iteration: 7 1.250000% labels changed Iteration: 8 0.875000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.875000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.375000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.125000% labels changed Iteration: 15 Labels unchanged! Terminating k-means. Iteration: 0 61.500000% labels changed Iteration: 1 13.375000% labels changed Iteration: 2 6.125000% labels changed Iteration: 3 2.000000% labels changed Iteration: 4 0.250000% labels changed Iteration: 5 Labels unchanged! Terminating k-means. Iteration: 0 70.750000% labels changed Iteration: 1 19.750000% labels changed Iteration: 2 5.625000% labels changed Iteration: 3 2.875000% labels changed Iteration: 4 3.875000% labels changed Iteration: 5 2.875000% labels changed Iteration: 6 3.125000% labels changed Iteration: 7 3.750000% labels changed Iteration: 8 2.625000% labels changed Iteration: 9 1.000000% labels changed Iteration: 10 1.375000% labels changed Iteration: 11 1.875000% labels changed Iteration: 12 1.375000% labels changed Iteration: 13 0.500000% labels changed Iteration: 14 0.750000% labels changed Iteration: 15 0.125000% labels changed Iteration: 16 0.250000% labels changed Iteration: 17 0.125000% labels changed Iteration: 18 0.250000% labels changed Iteration: 19 0.375000% labels changed Iteration: 20 0.125000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.250000% labels changed Iteration: 23 0.125000% labels changed Iteration: 24 0.125000% labels changed Iteration: 25 0.375000% labels changed Iteration: 26 0.250000% labels changed Iteration: 27 0.375000% labels changed Iteration: 28 0.750000% labels changed 
Iteration: 29 0.750000% labels changed Iteration: 30 1.125000% labels changed Iteration: 31 2.000000% labels changed Iteration: 32 4.250000% labels changed Iteration: 33 7.000000% labels changed Iteration: 34 6.250000% labels changed Iteration: 35 4.250000% labels changed Iteration: 36 4.500000% labels changed Iteration: 37 4.500000% labels changed Iteration: 38 2.375000% labels changed Iteration: 39 1.625000% labels changed Iteration: 40 1.375000% labels changed Iteration: 41 0.750000% labels changed Iteration: 42 1.125000% labels changed Iteration: 43 0.375000% labels changed Iteration: 44 Labels unchanged! Terminating k-means. Iteration: 0 71.875000% labels changed Iteration: 1 16.500000% labels changed Iteration: 2 5.750000% labels changed Iteration: 3 4.000000% labels changed Iteration: 4 4.750000% labels changed Iteration: 5 7.625000% labels changed Iteration: 6 5.625000% labels changed Iteration: 7 4.500000% labels changed Iteration: 8 2.375000% labels changed Iteration: 9 2.000000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 1.500000% labels changed Iteration: 12 0.875000% labels changed Iteration: 13 0.500000% labels changed Iteration: 14 0.250000% labels changed Iteration: 15 0.125000% labels changed Iteration: 16 Labels unchanged! Terminating k-means. Iteration: 0 69.125000% labels changed Iteration: 1 27.375000% labels changed Iteration: 2 14.000000% labels changed Iteration: 3 6.500000% labels changed Iteration: 4 2.250000% labels changed Iteration: 5 1.000000% labels changed Iteration: 6 0.750000% labels changed Iteration: 7 0.625000% labels changed Iteration: 8 0.750000% labels changed Iteration: 9 0.875000% labels changed Iteration: 10 1.000000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.125000% labels changed Iteration: 13 0.125000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. 
Iteration: 0 69.625000% labels changed Iteration: 1 25.375000% labels changed Iteration: 2 9.125000% labels changed Iteration: 3 2.875000% labels changed Iteration: 4 2.000000% labels changed Iteration: 5 1.625000% labels changed Iteration: 6 1.000000% labels changed Iteration: 7 0.500000% labels changed Iteration: 8 0.250000% labels changed Iteration: 9 0.250000% labels changed Iteration: 10 Labels unchanged! Terminating k-means. Iteration: 0 71.500000% labels changed Iteration: 1 23.250000% labels changed Iteration: 2 14.875000% labels changed Iteration: 3 8.250000% labels changed Iteration: 4 5.875000% labels changed Iteration: 5 4.375000% labels changed Iteration: 6 2.750000% labels changed Iteration: 7 2.125000% labels changed Iteration: 8 1.625000% labels changed Iteration: 9 1.125000% labels changed Iteration: 10 0.500000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 Labels unchanged! Terminating k-means. Iteration: 0 72.750000% labels changed Iteration: 1 31.000000% labels changed Iteration: 2 8.500000% labels changed Iteration: 3 2.875000% labels changed Iteration: 4 2.000000% labels changed Iteration: 5 0.750000% labels changed Iteration: 6 0.125000% labels changed Iteration: 7 Labels unchanged! Terminating k-means. Iteration: 0 75.000000% labels changed Iteration: 1 19.250000% labels changed Iteration: 2 7.125000% labels changed Iteration: 3 1.750000% labels changed Iteration: 4 0.625000% labels changed Iteration: 5 0.375000% labels changed Iteration: 6 0.375000% labels changed Iteration: 7 0.375000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.125000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. 
Iteration: 0 74.250000% labels changed Iteration: 1 20.125000% labels changed Iteration: 2 16.875000% labels changed Iteration: 3 9.125000% labels changed Iteration: 4 4.625000% labels changed Iteration: 5 2.375000% labels changed Iteration: 6 1.125000% labels changed Iteration: 7 0.125000% labels changed Iteration: 8 0.125000% labels changed Iteration: 9 Labels unchanged! Terminating k-means. Iteration: 0 75.125000% labels changed Iteration: 1 27.750000% labels changed Iteration: 2 11.750000% labels changed Iteration: 3 5.375000% labels changed Iteration: 4 5.250000% labels changed Iteration: 5 5.500000% labels changed Iteration: 6 4.000000% labels changed Iteration: 7 4.875000% labels changed Iteration: 8 3.375000% labels changed Iteration: 9 2.875000% labels changed Iteration: 10 1.750000% labels changed Iteration: 11 1.125000% labels changed Iteration: 12 1.250000% labels changed Iteration: 13 0.750000% labels changed Iteration: 14 0.750000% labels changed Iteration: 15 0.250000% labels changed Iteration: 16 0.125000% labels changed Iteration: 17 Labels unchanged! Terminating k-means. Iteration: 0 75.250000% labels changed Iteration: 1 26.500000% labels changed Iteration: 2 10.750000% labels changed Iteration: 3 8.875000% labels changed Iteration: 4 7.250000% labels changed Iteration: 5 5.500000% labels changed Iteration: 6 4.625000% labels changed Iteration: 7 4.625000% labels changed Iteration: 8 3.625000% labels changed Iteration: 9 2.000000% labels changed Iteration: 10 1.125000% labels changed Iteration: 11 0.750000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. 
Iteration: 0 78.250000% labels changed Iteration: 1 25.875000% labels changed Iteration: 2 12.250000% labels changed Iteration: 3 6.000000% labels changed Iteration: 4 4.750000% labels changed Iteration: 5 3.875000% labels changed Iteration: 6 2.750000% labels changed Iteration: 7 1.375000% labels changed Iteration: 8 1.375000% labels changed Iteration: 9 1.625000% labels changed Iteration: 10 1.750000% labels changed Iteration: 11 1.750000% labels changed Iteration: 12 1.625000% labels changed Iteration: 13 1.375000% labels changed Iteration: 14 1.625000% labels changed Iteration: 15 1.250000% labels changed Iteration: 16 1.375000% labels changed Iteration: 17 1.500000% labels changed Iteration: 18 1.625000% labels changed Iteration: 19 0.750000% labels changed Iteration: 20 0.250000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.250000% labels changed Iteration: 23 0.125000% labels changed Iteration: 24 0.250000% labels changed Iteration: 25 0.375000% labels changed Iteration: 26 0.500000% labels changed Iteration: 27 0.500000% labels changed Iteration: 28 0.375000% labels changed Iteration: 29 0.250000% labels changed Iteration: 30 0.250000% labels changed Iteration: 31 0.250000% labels changed Iteration: 32 Labels unchanged! Terminating k-means. Iteration: 0 78.500000% labels changed Iteration: 1 27.750000% labels changed Iteration: 2 9.625000% labels changed Iteration: 3 5.625000% labels changed Iteration: 4 5.000000% labels changed Iteration: 5 4.000000% labels changed Iteration: 6 2.500000% labels changed Iteration: 7 2.500000% labels changed Iteration: 8 1.375000% labels changed Iteration: 9 1.250000% labels changed Iteration: 10 0.500000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 Labels unchanged! Terminating k-means. 
Iteration: 0 78.250000% labels changed Iteration: 1 22.500000% labels changed Iteration: 2 12.500000% labels changed Iteration: 3 6.250000% labels changed Iteration: 4 4.250000% labels changed Iteration: 5 3.000000% labels changed Iteration: 6 2.000000% labels changed Iteration: 7 1.125000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.125000% labels changed Iteration: 11 0.125000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. Iteration: 0 77.250000% labels changed Iteration: 1 24.000000% labels changed Iteration: 2 8.750000% labels changed Iteration: 3 4.000000% labels changed Iteration: 4 3.125000% labels changed Iteration: 5 2.250000% labels changed Iteration: 6 0.625000% labels changed Iteration: 7 0.375000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.625000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.500000% labels changed Iteration: 13 0.375000% labels changed Iteration: 14 0.375000% labels changed Iteration: 15 0.375000% labels changed Iteration: 16 Labels unchanged! Terminating k-means. 
Iteration: 0 76.875000% labels changed Iteration: 1 23.375000% labels changed Iteration: 2 7.375000% labels changed Iteration: 3 4.875000% labels changed Iteration: 4 4.250000% labels changed Iteration: 5 2.750000% labels changed Iteration: 6 2.625000% labels changed Iteration: 7 2.625000% labels changed Iteration: 8 2.500000% labels changed Iteration: 9 1.875000% labels changed Iteration: 10 1.000000% labels changed Iteration: 11 0.875000% labels changed Iteration: 12 1.000000% labels changed Iteration: 13 1.000000% labels changed Iteration: 14 1.125000% labels changed Iteration: 15 2.125000% labels changed Iteration: 16 1.500000% labels changed Iteration: 17 1.750000% labels changed Iteration: 18 2.000000% labels changed Iteration: 19 1.750000% labels changed Iteration: 20 1.625000% labels changed Iteration: 21 1.625000% labels changed Iteration: 22 1.750000% labels changed Iteration: 23 2.250000% labels changed Iteration: 24 1.625000% labels changed Iteration: 25 0.500000% labels changed Iteration: 26 0.375000% labels changed Iteration: 27 0.250000% labels changed Iteration: 28 Labels unchanged! Terminating k-means. Iteration: 0 80.500000% labels changed Iteration: 1 26.875000% labels changed Iteration: 2 12.125000% labels changed Iteration: 3 7.000000% labels changed Iteration: 4 3.625000% labels changed Iteration: 5 3.375000% labels changed Iteration: 6 3.375000% labels changed Iteration: 7 3.125000% labels changed Iteration: 8 2.375000% labels changed Iteration: 9 2.875000% labels changed Iteration: 10 2.625000% labels changed Iteration: 11 2.125000% labels changed Iteration: 12 2.875000% labels changed Iteration: 13 2.125000% labels changed Iteration: 14 1.625000% labels changed Iteration: 15 2.250000% labels changed Iteration: 16 1.375000% labels changed Iteration: 17 1.375000% labels changed Iteration: 18 1.000000% labels changed Iteration: 19 0.375000% labels changed Iteration: 20 0.375000% labels changed Iteration: 21 Labels unchanged! 
Terminating k-means. Iteration: 0 80.250000% labels changed Iteration: 1 13.250000% labels changed Iteration: 2 7.500000% labels changed Iteration: 3 5.625000% labels changed Iteration: 4 4.125000% labels changed Iteration: 5 2.750000% labels changed Iteration: 6 2.375000% labels changed Iteration: 7 1.750000% labels changed Iteration: 8 1.750000% labels changed Iteration: 9 0.875000% labels changed Iteration: 10 0.375000% labels changed Iteration: 11 Labels unchanged! Terminating k-means. Iteration: 0 78.500000% labels changed Iteration: 1 16.750000% labels changed Iteration: 2 7.250000% labels changed Iteration: 3 6.125000% labels changed Iteration: 4 5.875000% labels changed Iteration: 5 5.500000% labels changed Iteration: 6 4.500000% labels changed Iteration: 7 3.750000% labels changed Iteration: 8 2.750000% labels changed Iteration: 9 2.000000% labels changed Iteration: 10 2.000000% labels changed Iteration: 11 0.625000% labels changed Iteration: 12 0.125000% labels changed Iteration: 13 Labels unchanged! Terminating k-means. Iteration: 0 80.250000% labels changed Iteration: 1 27.750000% labels changed Iteration: 2 12.125000% labels changed Iteration: 3 8.875000% labels changed Iteration: 4 5.500000% labels changed Iteration: 5 6.000000% labels changed Iteration: 6 8.250000% labels changed Iteration: 7 7.250000% labels changed Iteration: 8 5.000000% labels changed Iteration: 9 3.750000% labels changed Iteration: 10 2.750000% labels changed Iteration: 11 2.625000% labels changed Iteration: 12 1.000000% labels changed Iteration: 13 1.375000% labels changed Iteration: 14 0.875000% labels changed Iteration: 15 0.500000% labels changed Iteration: 16 Labels unchanged! Terminating k-means. 
Iteration: 0 79.125000% labels changed Iteration: 1 29.625000% labels changed Iteration: 2 14.250000% labels changed Iteration: 3 9.875000% labels changed Iteration: 4 5.875000% labels changed Iteration: 5 4.250000% labels changed Iteration: 6 3.250000% labels changed Iteration: 7 3.375000% labels changed Iteration: 8 3.500000% labels changed Iteration: 9 3.125000% labels changed Iteration: 10 2.875000% labels changed Iteration: 11 2.875000% labels changed Iteration: 12 1.750000% labels changed Iteration: 13 1.875000% labels changed Iteration: 14 2.625000% labels changed Iteration: 15 1.625000% labels changed Iteration: 16 1.000000% labels changed Iteration: 17 0.750000% labels changed Iteration: 18 0.375000% labels changed Iteration: 19 0.125000% labels changed Iteration: 20 0.125000% labels changed Iteration: 21 Labels unchanged! Terminating k-means. Iteration: 0 84.000000% labels changed Iteration: 1 34.500000% labels changed Iteration: 2 10.750000% labels changed Iteration: 3 5.125000% labels changed Iteration: 4 3.500000% labels changed Iteration: 5 3.125000% labels changed Iteration: 6 3.000000% labels changed Iteration: 7 1.875000% labels changed Iteration: 8 1.125000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.125000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 Labels unchanged! Terminating k-means. Iteration: 0 82.875000% labels changed Iteration: 1 21.375000% labels changed Iteration: 2 8.125000% labels changed Iteration: 3 3.875000% labels changed Iteration: 4 1.375000% labels changed Iteration: 5 0.500000% labels changed Iteration: 6 0.250000% labels changed Iteration: 7 Labels unchanged! Terminating k-means. 
Iteration: 0 82.875000% labels changed Iteration: 1 24.375000% labels changed Iteration: 2 9.000000% labels changed Iteration: 3 3.250000% labels changed Iteration: 4 3.125000% labels changed Iteration: 5 3.000000% labels changed Iteration: 6 2.250000% labels changed Iteration: 7 2.000000% labels changed Iteration: 8 1.500000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 1.000000% labels changed Iteration: 11 0.750000% labels changed Iteration: 12 0.875000% labels changed Iteration: 13 0.750000% labels changed Iteration: 14 0.875000% labels changed Iteration: 15 0.500000% labels changed Iteration: 16 0.750000% labels changed Iteration: 17 0.875000% labels changed Iteration: 18 0.500000% labels changed Iteration: 19 0.500000% labels changed Iteration: 20 0.625000% labels changed Iteration: 21 0.750000% labels changed Iteration: 22 0.250000% labels changed Iteration: 23 Labels unchanged! Terminating k-means. Iteration: 0 81.750000% labels changed Iteration: 1 25.125000% labels changed Iteration: 2 14.750000% labels changed Iteration: 3 10.375000% labels changed Iteration: 4 6.875000% labels changed Iteration: 5 4.500000% labels changed Iteration: 6 2.875000% labels changed Iteration: 7 2.250000% labels changed Iteration: 8 1.750000% labels changed Iteration: 9 1.875000% labels changed Iteration: 10 1.625000% labels changed Iteration: 11 1.000000% labels changed Iteration: 12 1.000000% labels changed Iteration: 13 0.750000% labels changed Iteration: 14 0.750000% labels changed Iteration: 15 0.500000% labels changed Iteration: 16 0.625000% labels changed Iteration: 17 0.750000% labels changed Iteration: 18 0.625000% labels changed Iteration: 19 0.375000% labels changed Iteration: 20 0.875000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.375000% labels changed Iteration: 23 0.375000% labels changed Iteration: 24 0.375000% labels changed Iteration: 25 0.500000% labels changed Iteration: 26 0.250000% labels changed Iteration: 
27 0.125000% labels changed Iteration: 28 0.125000% labels changed Iteration: 29 0.250000% labels changed Iteration: 30 0.625000% labels changed Iteration: 31 0.500000% labels changed Iteration: 32 0.625000% labels changed Iteration: 33 0.500000% labels changed Iteration: 34 0.250000% labels changed Iteration: 35 0.250000% labels changed Iteration: 36 Labels unchanged! Terminating k-means. Iteration: 0 82.125000% labels changed Iteration: 1 29.750000% labels changed Iteration: 2 15.500000% labels changed Iteration: 3 7.375000% labels changed Iteration: 4 7.250000% labels changed Iteration: 5 6.875000% labels changed Iteration: 6 5.250000% labels changed Iteration: 7 4.625000% labels changed Iteration: 8 2.750000% labels changed Iteration: 9 1.250000% labels changed Iteration: 10 0.125000% labels changed Iteration: 11 Labels unchanged! Terminating k-means. Iteration: 0 82.000000% labels changed Iteration: 1 31.625000% labels changed Iteration: 2 10.125000% labels changed Iteration: 3 3.375000% labels changed Iteration: 4 3.000000% labels changed Iteration: 5 1.125000% labels changed Iteration: 6 1.250000% labels changed Iteration: 7 1.125000% labels changed Iteration: 8 0.625000% labels changed Iteration: 9 0.250000% labels changed Iteration: 10 0.125000% labels changed Iteration: 11 Labels unchanged! Terminating k-means. 
Iteration: 0 84.500000% labels changed Iteration: 1 23.625000% labels changed Iteration: 2 8.375000% labels changed Iteration: 3 6.000000% labels changed Iteration: 4 6.000000% labels changed Iteration: 5 4.125000% labels changed Iteration: 6 3.375000% labels changed Iteration: 7 2.750000% labels changed Iteration: 8 2.750000% labels changed Iteration: 9 2.375000% labels changed Iteration: 10 1.750000% labels changed Iteration: 11 1.500000% labels changed Iteration: 12 1.500000% labels changed Iteration: 13 0.875000% labels changed Iteration: 14 0.875000% labels changed Iteration: 15 0.500000% labels changed Iteration: 16 0.750000% labels changed Iteration: 17 0.750000% labels changed Iteration: 18 0.125000% labels changed Iteration: 19 0.250000% labels changed Iteration: 20 Labels unchanged! Terminating k-means. Iteration: 0 85.375000% labels changed Iteration: 1 28.250000% labels changed Iteration: 2 9.750000% labels changed Iteration: 3 3.625000% labels changed Iteration: 4 2.500000% labels changed Iteration: 5 1.250000% labels changed Iteration: 6 0.250000% labels changed Iteration: 7 0.375000% labels changed Iteration: 8 0.500000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 0.750000% labels changed Iteration: 12 2.000000% labels changed Iteration: 13 1.500000% labels changed Iteration: 14 1.000000% labels changed Iteration: 15 0.750000% labels changed Iteration: 16 0.625000% labels changed Iteration: 17 0.500000% labels changed Iteration: 18 0.375000% labels changed Iteration: 19 0.500000% labels changed Iteration: 20 0.250000% labels changed Iteration: 21 Labels unchanged! Terminating k-means. 
Iteration: 0 83.750000% labels changed Iteration: 1 25.000000% labels changed Iteration: 2 13.125000% labels changed Iteration: 3 7.125000% labels changed Iteration: 4 4.375000% labels changed Iteration: 5 3.250000% labels changed Iteration: 6 2.250000% labels changed Iteration: 7 1.500000% labels changed Iteration: 8 0.750000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.375000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.500000% labels changed Iteration: 15 0.375000% labels changed Iteration: 16 0.500000% labels changed Iteration: 17 0.375000% labels changed Iteration: 18 0.500000% labels changed Iteration: 19 0.375000% labels changed Iteration: 20 0.750000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.500000% labels changed Iteration: 23 0.625000% labels changed Iteration: 24 0.875000% labels changed Iteration: 25 1.000000% labels changed Iteration: 26 1.000000% labels changed Iteration: 27 1.125000% labels changed Iteration: 28 0.500000% labels changed Iteration: 29 0.125000% labels changed Iteration: 30 0.125000% labels changed Iteration: 31 Labels unchanged! Terminating k-means. Iteration: 0 82.625000% labels changed Iteration: 1 31.250000% labels changed Iteration: 2 15.750000% labels changed Iteration: 3 8.125000% labels changed Iteration: 4 3.875000% labels changed Iteration: 5 2.500000% labels changed Iteration: 6 1.875000% labels changed Iteration: 7 1.625000% labels changed Iteration: 8 0.625000% labels changed Iteration: 9 0.500000% labels changed Iteration: 10 0.125000% labels changed Iteration: 11 Labels unchanged! Terminating k-means. 
Iteration: 0 84.000000% labels changed Iteration: 1 36.375000% labels changed Iteration: 2 15.375000% labels changed Iteration: 3 8.750000% labels changed Iteration: 4 4.500000% labels changed Iteration: 5 2.375000% labels changed Iteration: 6 2.125000% labels changed Iteration: 7 1.750000% labels changed Iteration: 8 1.625000% labels changed Iteration: 9 1.250000% labels changed Iteration: 10 1.125000% labels changed Iteration: 11 1.125000% labels changed Iteration: 12 1.625000% labels changed Iteration: 13 1.500000% labels changed Iteration: 14 1.125000% labels changed Iteration: 15 0.875000% labels changed Iteration: 16 0.750000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 Labels unchanged! Terminating k-means. Iteration: 0 85.000000% labels changed Iteration: 1 29.500000% labels changed Iteration: 2 10.250000% labels changed Iteration: 3 7.375000% labels changed Iteration: 4 6.625000% labels changed Iteration: 5 5.625000% labels changed Iteration: 6 4.250000% labels changed Iteration: 7 3.000000% labels changed Iteration: 8 2.125000% labels changed Iteration: 9 2.000000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. Iteration: 0 84.375000% labels changed Iteration: 1 31.875000% labels changed Iteration: 2 11.500000% labels changed Iteration: 3 5.375000% labels changed Iteration: 4 3.500000% labels changed Iteration: 5 1.625000% labels changed Iteration: 6 1.000000% labels changed Iteration: 7 1.125000% labels changed Iteration: 8 0.875000% labels changed Iteration: 9 1.500000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 0.625000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.125000% labels changed Iteration: 14 0.125000% labels changed Iteration: 15 Labels unchanged! Terminating k-means. 
Iteration: 0 84.500000% labels changed Iteration: 1 28.625000% labels changed Iteration: 2 13.375000% labels changed Iteration: 3 6.250000% labels changed Iteration: 4 4.875000% labels changed Iteration: 5 4.750000% labels changed Iteration: 6 3.250000% labels changed Iteration: 7 3.125000% labels changed Iteration: 8 3.125000% labels changed Iteration: 9 2.875000% labels changed Iteration: 10 2.250000% labels changed Iteration: 11 2.375000% labels changed Iteration: 12 1.750000% labels changed Iteration: 13 2.250000% labels changed Iteration: 14 2.000000% labels changed Iteration: 15 1.625000% labels changed Iteration: 16 1.250000% labels changed Iteration: 17 1.375000% labels changed Iteration: 18 1.375000% labels changed Iteration: 19 0.625000% labels changed Iteration: 20 0.500000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.375000% labels changed Iteration: 23 0.500000% labels changed Iteration: 24 0.625000% labels changed Iteration: 25 0.750000% labels changed Iteration: 26 0.125000% labels changed Iteration: 27 0.125000% labels changed Iteration: 28 0.250000% labels changed Iteration: 29 0.125000% labels changed Iteration: 30 0.250000% labels changed Iteration: 31 0.250000% labels changed Iteration: 32 Labels unchanged! Terminating k-means. Iteration: 0 82.000000% labels changed Iteration: 1 28.750000% labels changed Iteration: 2 16.750000% labels changed Iteration: 3 8.875000% labels changed Iteration: 4 3.875000% labels changed Iteration: 5 2.125000% labels changed Iteration: 6 1.750000% labels changed Iteration: 7 1.500000% labels changed Iteration: 8 0.375000% labels changed Iteration: 9 0.125000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. 
Iteration: 0 85.750000% labels changed Iteration: 1 20.375000% labels changed Iteration: 2 10.500000% labels changed Iteration: 3 6.625000% labels changed Iteration: 4 4.125000% labels changed Iteration: 5 2.375000% labels changed Iteration: 6 2.000000% labels changed Iteration: 7 1.750000% labels changed Iteration: 8 1.375000% labels changed Iteration: 9 1.500000% labels changed Iteration: 10 1.000000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.375000% labels changed Iteration: 14 0.500000% labels changed Iteration: 15 0.375000% labels changed Iteration: 16 0.375000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 0.250000% labels changed Iteration: 19 0.125000% labels changed Iteration: 20 Labels unchanged! Terminating k-means. Iteration: 0 85.250000% labels changed Iteration: 1 23.875000% labels changed Iteration: 2 9.625000% labels changed Iteration: 3 4.125000% labels changed Iteration: 4 2.500000% labels changed Iteration: 5 2.000000% labels changed Iteration: 6 1.250000% labels changed Iteration: 7 0.375000% labels changed Iteration: 8 0.125000% labels changed Iteration: 9 0.500000% labels changed Iteration: 10 0.375000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.500000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.250000% labels changed Iteration: 15 0.250000% labels changed Iteration: 16 Labels unchanged! Terminating k-means. 
Iteration: 0 85.250000% labels changed Iteration: 1 23.125000% labels changed Iteration: 2 10.875000% labels changed Iteration: 3 4.625000% labels changed Iteration: 4 3.750000% labels changed Iteration: 5 3.625000% labels changed Iteration: 6 3.000000% labels changed Iteration: 7 2.250000% labels changed Iteration: 8 2.750000% labels changed Iteration: 9 2.750000% labels changed Iteration: 10 3.125000% labels changed Iteration: 11 1.875000% labels changed Iteration: 12 1.750000% labels changed Iteration: 13 1.750000% labels changed Iteration: 14 0.875000% labels changed Iteration: 15 0.875000% labels changed Iteration: 16 0.250000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 0.500000% labels changed Iteration: 19 0.875000% labels changed Iteration: 20 0.375000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.250000% labels changed Iteration: 23 Labels unchanged! Terminating k-means. Iteration: 0 86.000000% labels changed Iteration: 1 32.125000% labels changed Iteration: 2 14.375000% labels changed Iteration: 3 7.250000% labels changed Iteration: 4 3.750000% labels changed Iteration: 5 3.125000% labels changed Iteration: 6 3.750000% labels changed Iteration: 7 3.125000% labels changed Iteration: 8 2.250000% labels changed Iteration: 9 1.000000% labels changed Iteration: 10 0.625000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. 
Iteration: 0 85.750000% labels changed Iteration: 1 18.875000% labels changed Iteration: 2 11.375000% labels changed Iteration: 3 8.875000% labels changed Iteration: 4 5.500000% labels changed Iteration: 5 3.875000% labels changed Iteration: 6 3.125000% labels changed Iteration: 7 4.375000% labels changed Iteration: 8 3.375000% labels changed Iteration: 9 2.875000% labels changed Iteration: 10 2.375000% labels changed Iteration: 11 1.500000% labels changed Iteration: 12 1.000000% labels changed Iteration: 13 0.750000% labels changed Iteration: 14 0.375000% labels changed Iteration: 15 0.125000% labels changed Iteration: 16 Labels unchanged! Terminating k-means. Iteration: 0 89.625000% labels changed Iteration: 1 37.500000% labels changed Iteration: 2 16.875000% labels changed Iteration: 3 9.625000% labels changed Iteration: 4 5.250000% labels changed Iteration: 5 3.125000% labels changed Iteration: 6 2.000000% labels changed Iteration: 7 0.750000% labels changed Iteration: 8 0.625000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 0.500000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. 
Iteration: 0 84.875000% labels changed Iteration: 1 40.375000% labels changed Iteration: 2 8.750000% labels changed Iteration: 3 3.875000% labels changed Iteration: 4 2.500000% labels changed Iteration: 5 1.750000% labels changed Iteration: 6 1.625000% labels changed Iteration: 7 1.250000% labels changed Iteration: 8 1.250000% labels changed Iteration: 9 1.250000% labels changed Iteration: 10 1.250000% labels changed Iteration: 11 1.250000% labels changed Iteration: 12 0.625000% labels changed Iteration: 13 0.750000% labels changed Iteration: 14 1.000000% labels changed Iteration: 15 1.125000% labels changed Iteration: 16 0.500000% labels changed Iteration: 17 0.500000% labels changed Iteration: 18 0.500000% labels changed Iteration: 19 0.750000% labels changed Iteration: 20 0.500000% labels changed Iteration: 21 0.500000% labels changed Iteration: 22 0.375000% labels changed Iteration: 23 Labels unchanged! Terminating k-means. Iteration: 0 87.250000% labels changed Iteration: 1 28.500000% labels changed Iteration: 2 8.375000% labels changed Iteration: 3 5.125000% labels changed Iteration: 4 5.375000% labels changed Iteration: 5 5.125000% labels changed Iteration: 6 3.500000% labels changed Iteration: 7 1.875000% labels changed Iteration: 8 1.500000% labels changed Iteration: 9 1.000000% labels changed Iteration: 10 0.500000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.375000% labels changed Iteration: 13 0.125000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. 
Iteration: 0 85.000000% labels changed Iteration: 1 25.375000% labels changed Iteration: 2 10.875000% labels changed Iteration: 3 7.625000% labels changed Iteration: 4 5.750000% labels changed Iteration: 5 3.375000% labels changed Iteration: 6 2.000000% labels changed Iteration: 7 1.000000% labels changed Iteration: 8 0.750000% labels changed Iteration: 9 0.500000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.125000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.250000% labels changed Iteration: 15 0.375000% labels changed Iteration: 16 0.375000% labels changed Iteration: 17 0.125000% labels changed Iteration: 18 Labels unchanged! Terminating k-means. Iteration: 0 86.625000% labels changed Iteration: 1 30.250000% labels changed Iteration: 2 10.875000% labels changed Iteration: 3 6.500000% labels changed Iteration: 4 4.000000% labels changed Iteration: 5 2.750000% labels changed Iteration: 6 2.500000% labels changed Iteration: 7 2.375000% labels changed Iteration: 8 2.375000% labels changed Iteration: 9 2.000000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 Labels unchanged! Terminating k-means. Iteration: 0 87.625000% labels changed Iteration: 1 26.125000% labels changed Iteration: 2 11.500000% labels changed Iteration: 3 6.625000% labels changed Iteration: 4 3.250000% labels changed Iteration: 5 1.875000% labels changed Iteration: 6 0.875000% labels changed Iteration: 7 0.875000% labels changed Iteration: 8 0.875000% labels changed Iteration: 9 0.625000% labels changed Iteration: 10 0.875000% labels changed Iteration: 11 0.750000% labels changed Iteration: 12 0.375000% labels changed Iteration: 13 0.375000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. 
Iteration: 0 86.625000% labels changed Iteration: 1 31.000000% labels changed Iteration: 2 12.000000% labels changed Iteration: 3 5.875000% labels changed Iteration: 4 3.875000% labels changed Iteration: 5 2.125000% labels changed Iteration: 6 0.875000% labels changed Iteration: 7 0.750000% labels changed Iteration: 8 0.875000% labels changed Iteration: 9 1.125000% labels changed Iteration: 10 0.875000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.500000% labels changed Iteration: 14 0.375000% labels changed Iteration: 15 0.250000% labels changed Iteration: 16 0.375000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 0.250000% labels changed Iteration: 19 Labels unchanged! Terminating k-means. Iteration: 0 89.250000% labels changed Iteration: 1 33.500000% labels changed Iteration: 2 12.750000% labels changed Iteration: 3 9.125000% labels changed Iteration: 4 5.250000% labels changed Iteration: 5 3.500000% labels changed Iteration: 6 2.625000% labels changed Iteration: 7 1.625000% labels changed Iteration: 8 1.750000% labels changed Iteration: 9 1.375000% labels changed Iteration: 10 1.500000% labels changed Iteration: 11 1.250000% labels changed Iteration: 12 0.750000% labels changed Iteration: 13 1.000000% labels changed Iteration: 14 0.750000% labels changed Iteration: 15 0.625000% labels changed Iteration: 16 0.500000% labels changed Iteration: 17 0.375000% labels changed Iteration: 18 0.375000% labels changed Iteration: 19 Labels unchanged! Terminating k-means. 
Iteration: 0 87.250000% labels changed Iteration: 1 26.250000% labels changed Iteration: 2 14.125000% labels changed Iteration: 3 9.750000% labels changed Iteration: 4 6.500000% labels changed Iteration: 5 4.250000% labels changed Iteration: 6 2.250000% labels changed Iteration: 7 1.875000% labels changed Iteration: 8 1.625000% labels changed Iteration: 9 0.875000% labels changed Iteration: 10 1.000000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.500000% labels changed Iteration: 15 0.875000% labels changed Iteration: 16 0.625000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 0.125000% labels changed Iteration: 19 Labels unchanged! Terminating k-means. Iteration: 0 88.375000% labels changed Iteration: 1 30.750000% labels changed Iteration: 2 11.500000% labels changed Iteration: 3 8.625000% labels changed Iteration: 4 5.000000% labels changed Iteration: 5 3.875000% labels changed Iteration: 6 3.625000% labels changed Iteration: 7 2.875000% labels changed Iteration: 8 2.250000% labels changed Iteration: 9 1.250000% labels changed Iteration: 10 0.875000% labels changed Iteration: 11 0.375000% labels changed Iteration: 12 0.250000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.250000% labels changed Iteration: 15 0.250000% labels changed Iteration: 16 0.125000% labels changed Iteration: 17 Labels unchanged! Terminating k-means. 
Iteration: 0 87.875000% labels changed Iteration: 1 22.625000% labels changed Iteration: 2 9.625000% labels changed Iteration: 3 6.125000% labels changed Iteration: 4 5.875000% labels changed Iteration: 5 6.500000% labels changed Iteration: 6 4.625000% labels changed Iteration: 7 2.000000% labels changed Iteration: 8 1.500000% labels changed Iteration: 9 0.375000% labels changed Iteration: 10 0.375000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 0.500000% labels changed Iteration: 13 0.500000% labels changed Iteration: 14 0.875000% labels changed Iteration: 15 0.625000% labels changed Iteration: 16 0.875000% labels changed Iteration: 17 1.250000% labels changed Iteration: 18 1.000000% labels changed Iteration: 19 0.500000% labels changed Iteration: 20 0.250000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.375000% labels changed Iteration: 23 0.250000% labels changed Iteration: 24 0.375000% labels changed Iteration: 25 0.125000% labels changed Iteration: 26 0.125000% labels changed Iteration: 27 0.125000% labels changed Iteration: 28 0.250000% labels changed Iteration: 29 0.250000% labels changed Iteration: 30 0.500000% labels changed Iteration: 31 0.375000% labels changed Iteration: 32 0.250000% labels changed Iteration: 33 0.375000% labels changed Iteration: 34 0.250000% labels changed Iteration: 35 0.625000% labels changed Iteration: 36 0.125000% labels changed Iteration: 37 0.125000% labels changed Iteration: 38 Labels unchanged! Terminating k-means. 
Iteration: 0 87.625000% labels changed Iteration: 1 29.125000% labels changed Iteration: 2 13.125000% labels changed Iteration: 3 7.750000% labels changed Iteration: 4 5.000000% labels changed Iteration: 5 2.875000% labels changed Iteration: 6 2.125000% labels changed Iteration: 7 1.625000% labels changed Iteration: 8 1.125000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 0.125000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.250000% labels changed Iteration: 15 Labels unchanged! Terminating k-means. Iteration: 0 86.625000% labels changed Iteration: 1 25.750000% labels changed Iteration: 2 13.875000% labels changed Iteration: 3 8.000000% labels changed Iteration: 4 5.000000% labels changed Iteration: 5 3.000000% labels changed Iteration: 6 2.750000% labels changed Iteration: 7 1.750000% labels changed Iteration: 8 1.750000% labels changed Iteration: 9 1.125000% labels changed Iteration: 10 0.750000% labels changed Iteration: 11 0.625000% labels changed Iteration: 12 0.500000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 0.375000% labels changed Iteration: 15 0.250000% labels changed Iteration: 16 0.250000% labels changed Iteration: 17 0.250000% labels changed Iteration: 18 0.250000% labels changed Iteration: 19 0.125000% labels changed Iteration: 20 Labels unchanged! Terminating k-means. Iteration: 0 89.125000% labels changed Iteration: 1 26.375000% labels changed Iteration: 2 13.875000% labels changed Iteration: 3 6.500000% labels changed Iteration: 4 3.625000% labels changed Iteration: 5 2.375000% labels changed Iteration: 6 2.000000% labels changed Iteration: 7 1.125000% labels changed Iteration: 8 0.625000% labels changed Iteration: 9 0.250000% labels changed Iteration: 10 Labels unchanged! Terminating k-means. 
Iteration: 0 86.750000% labels changed Iteration: 1 30.125000% labels changed Iteration: 2 10.625000% labels changed Iteration: 3 6.625000% labels changed Iteration: 4 4.875000% labels changed Iteration: 5 3.375000% labels changed Iteration: 6 3.000000% labels changed Iteration: 7 4.375000% labels changed Iteration: 8 4.750000% labels changed Iteration: 9 4.250000% labels changed Iteration: 10 3.375000% labels changed Iteration: 11 3.000000% labels changed Iteration: 12 1.625000% labels changed Iteration: 13 0.250000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. Iteration: 0 89.625000% labels changed Iteration: 1 36.125000% labels changed Iteration: 2 12.000000% labels changed Iteration: 3 7.125000% labels changed Iteration: 4 6.250000% labels changed Iteration: 5 3.750000% labels changed Iteration: 6 2.125000% labels changed Iteration: 7 1.750000% labels changed Iteration: 8 1.500000% labels changed Iteration: 9 1.125000% labels changed Iteration: 10 1.250000% labels changed Iteration: 11 1.375000% labels changed Iteration: 12 1.875000% labels changed Iteration: 13 1.125000% labels changed Iteration: 14 1.000000% labels changed Iteration: 15 0.500000% labels changed Iteration: 16 0.125000% labels changed Iteration: 17 Labels unchanged! Terminating k-means. Iteration: 0 88.000000% labels changed Iteration: 1 34.375000% labels changed Iteration: 2 14.625000% labels changed Iteration: 3 7.250000% labels changed Iteration: 4 4.500000% labels changed Iteration: 5 4.000000% labels changed Iteration: 6 2.875000% labels changed Iteration: 7 2.125000% labels changed Iteration: 8 1.750000% labels changed Iteration: 9 1.375000% labels changed Iteration: 10 0.250000% labels changed Iteration: 11 0.500000% labels changed Iteration: 12 0.125000% labels changed Iteration: 13 0.125000% labels changed Iteration: 14 Labels unchanged! Terminating k-means. 
Iteration: 0 89.250000% labels changed Iteration: 1 28.125000% labels changed Iteration: 2 12.000000% labels changed Iteration: 3 5.750000% labels changed Iteration: 4 2.625000% labels changed Iteration: 5 3.000000% labels changed Iteration: 6 2.500000% labels changed Iteration: 7 0.500000% labels changed Iteration: 8 0.250000% labels changed Iteration: 9 0.125000% labels changed Iteration: 10 Labels unchanged! Terminating k-means. Iteration: 0 90.000000% labels changed Iteration: 1 28.875000% labels changed Iteration: 2 13.250000% labels changed Iteration: 3 8.250000% labels changed Iteration: 4 5.625000% labels changed Iteration: 5 4.125000% labels changed Iteration: 6 3.625000% labels changed Iteration: 7 1.375000% labels changed Iteration: 8 0.750000% labels changed Iteration: 9 Labels unchanged! Terminating k-means. Iteration: 0 87.625000% labels changed Iteration: 1 26.250000% labels changed Iteration: 2 9.375000% labels changed Iteration: 3 6.375000% labels changed Iteration: 4 4.375000% labels changed Iteration: 5 3.250000% labels changed Iteration: 6 2.375000% labels changed Iteration: 7 2.000000% labels changed Iteration: 8 1.000000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.500000% labels changed Iteration: 11 0.750000% labels changed Iteration: 12 1.250000% labels changed Iteration: 13 0.875000% labels changed Iteration: 14 1.125000% labels changed Iteration: 15 1.375000% labels changed Iteration: 16 1.500000% labels changed Iteration: 17 1.625000% labels changed Iteration: 18 0.875000% labels changed Iteration: 19 0.625000% labels changed Iteration: 20 0.375000% labels changed Iteration: 21 0.250000% labels changed Iteration: 22 0.125000% labels changed Iteration: 23 0.250000% labels changed Iteration: 24 Labels unchanged! Terminating k-means. The optimal k and CH index are: 4 and 233.02152942588236
# plot the average CH index against k and mark the maximiser in red
ks = range(2, k_range)
plt.plot(ks, k_range_ch_index, color="blue")
plt.scatter(ks, k_range_ch_index, color="blue")
best = np.argmax(k_range_ch_index)
plt.scatter(2 + best, np.max(k_range_ch_index), color="red", label="optimal point")
plt.xlabel("The number of clusters k")
plt.ylabel("The value of Calinski-Harabasz index")
plt.title("Plot of Average Calinski-Harabasz Index Against Cluster Number k")
plt.legend(loc = 'upper right')
plt.show()
# find the size of clusters found by k-means with the optimal k
# (the CH index list starts at k = 2, hence the offset)
opt_k = np.argmax(k_range_ch_index) + 2
opt_centroids, opt_labels = k_clustering(gene_expression, opt_k, max_iter=70, message=True)
Iteration: 0 71.500000% labels changed Iteration: 1 29.375000% labels changed Iteration: 2 11.000000% labels changed Iteration: 3 9.625000% labels changed Iteration: 4 7.375000% labels changed Iteration: 5 4.250000% labels changed Iteration: 6 2.875000% labels changed Iteration: 7 1.500000% labels changed Iteration: 8 1.250000% labels changed Iteration: 9 0.750000% labels changed Iteration: 10 0.500000% labels changed Iteration: 11 0.250000% labels changed Iteration: 12 Labels unchanged! Terminating k-means.
# tally the number of points assigned to each cluster under the optimal k
cluster_sizes = defaultdict(int)
for label in range(opt_k):
    cluster_sizes[label] = np.sum(opt_labels == label)
print("The optimal value of k is: ", opt_k)
print("The cluster sizes found by the optimal k are: ", cluster_sizes)
The optimal value of k is: 4
The cluster sizes found by the optimal k are: defaultdict(<class 'int'>, {0: 303, 1: 135, 2: 221, 3: 141})
# recover the optimal k found in 2.1.1 (CH index evaluated from k = 2 upward)
opt_k = np.argmax(k_range_ch_index) + 2
def h_score(opt_k, gene_type):
    """
    Return the homogeneity score 1 - H(C|K)/H(C), averaged over 5 k-means runs.

    Args:
        opt_k: optimal number of clusters k from 2.1.1.
        gene_type: true labels of the gene expressions (ndarray or pandas Series).

    Returns:
        float: average homogeneity score in [0, 1]; higher means the clusters
        align better with the true labels.

    Bug fix vs. the original: h_c and h_ck were accumulated across the 5 runs
    without being reset, so h_c_lis held [h, 2h, 3h, ...] and h_ck_lis held
    cumulative sums — a biased, run-order-dependent average. H(C) is now
    computed once (it depends only on the true labels) and H(C|K) is reset
    per run.
    """
    if type(gene_type) != np.ndarray:
        gene_type = gene_type.to_numpy()
    types = np.unique(gene_type)
    n = len(gene_type)
    # H(C) depends only on the true labels, so compute it once outside the loop
    h_c = 0.0
    for c in types:
        p_c = np.sum(gene_type == c) / n
        h_c -= p_c * np.log(p_c)
    h_ck_lis = []
    for _ in range(5):
        up_centroids, up_labels = k_clustering(gene_expression, opt_k, max_iter=20, message=False)
        # contingency counts a[(c, k)] for this run only
        a = defaultdict(int)
        for c in types:
            for k in range(opt_k):
                a[(c, k)] = np.sum((gene_type == c) & (up_labels == k))
        # H(C|K) for this run; a zero count contributes 0 (lim x->0 of x log x)
        h_ck = 0.0
        for k in range(opt_k):
            sig_cck = sum(a[(j, k)] for j in types)
            for c in types:
                if a[(c, k)] != 0:
                    h_ck -= a[(c, k)] / n * np.log(a[(c, k)] / sig_cck)
        h_ck_lis.append(h_ck)
    return 1 - np.mean(h_ck_lis) / h_c
# print the homogeneity score at the optimal k
# (h_score averages over 5 fresh k-means runs, so results vary slightly per call)
hc = h_score(opt_k, gene_type)
print("The homogeneity score is: ", hc)
The homogeneity score is: 0.5824509972004817
First to note: in defining h_score (homogeneity score), terms where a[(c, k)] = 0 are skipped — this corresponds to the case where none of the points in class c has been clustered into cluster k. This will lead to a slight underestimation of the true homogeneity score.
By the set-up of h_score, the higher the score, the better the alignment with the original labels. From this perspective, the clustered labels are not highly consistent with the original labels.
Although CH_k gives the optimal clustering case, there could be some reasons for low h_score:
k-means is dependent on initialization. In this very first stage, if several points from different true classes are assigned the same cluster label, they are likely to form an individual cluster and will affect the outcome of clustering. If time permits, several representative sample points should be chosen as initial centroids.
"CH_index" gives optimal clustering but necessarily the 'best' clustering. It has been shown to be sensitive to the shape of the clusters and the presence of outliers. If the initial assignments of labels give a considerable number of outliers, and they will be counted towards a certain cluster, despite the largest distance. This will consequently effect the results of centroids and thus the result of clutering. If time permits, checks for outliers and other structures of the label initialization should be implemented, and other measures of quality should be used in combination with this.
Note: the homogeneity score increases with k (shown below). This means that the greater the k, the better the alignment of the clustering. But too large a k is likely to cause overfitting. Thus, the value of k should be chosen carefully.
# homogeneity score as a function of k, to show it increases with k
h_l = []
for k in range(2, 14):
    h_l.append(h_score(k, gene_type))
plt.plot(range(2, 14), h_l, label="h_score")  # fixed legend typo ("h_socre")
plt.scatter(range(2, 14), h_l)
plt.xlabel("cluster number k")
plt.ylabel("homogeneity score")
plt.title("h score against number of clusters")
plt.legend()
plt.show()
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_score
# map each class name to an integer code (for reference)
types = np.unique(gene_type)
c_dict = dict(zip(types, range(len(types))))
print(c_dict)
# sanity check: fit sklearn's k-means with the same k and compare its
# homogeneity score against our from-scratch implementation
kmeans = KMeans(n_clusters=opt_k)
kmeans.fit(gene_expression)
labels = kmeans.predict(gene_expression)
print(homogeneity_score(gene_type, labels))
{'BRCA': 0, 'COAD': 1, 'KIRC': 2, 'LUAD': 3, 'PRAD': 4}
0.586797136972071
For a network, centrality (representative nodes), communities, and modularity (the degree of connections within and between communities) are topics of concern. In this section, these topics are applied and discussed.
# recall the unnormalized data: gene_data (last column is the label)
gene_expr_22 = gene_data[gene_data.columns[:-1]].astype(float).to_numpy()
# gene-gene correlation matrix (genes are columns, hence rowvar=False)
cor_mat = np.corrcoef(gene_expr_22, rowvar=False)
# turn correlations into an adjacency matrix:
# no self-loops, and only keep strong (|r| >= 0.75) correlations
np.fill_diagonal(cor_mat, 0)
weak = np.abs(cor_mat) < 0.75
cor_mat[weak] = 0
A = cor_mat
print(gene_expr_22.shape)
print(cor_mat.shape)
(800, 95) (95, 95)
# visualize the thresholded adjacency matrix as a heatmap
# (trailing ';' suppresses the colorbar object's repr in the notebook)
plt.imshow(A)
plt.colorbar();
# degree centrality: weighted degree of each gene in the correlation network
degree = A.sum(axis=1)
sorted_index = np.argsort(degree)[::-1]  # gene indices, most central first
# genes are COLUMNS of gene_expr_22 (samples x genes), so reorder columns,
# not rows as the original did
sorted_gene_expr_22 = gene_expr_22[:, sorted_index]
print("The index of the five topping ranking genes are: ", sorted_index[0:5])
print("The top 5 centralties are: ", degree[sorted_index[0:5]])
print("The top 5 ranking gene expressions are:")
for i in range(5):
    # expression profile (over all samples) of the i-th most central gene;
    # the original printed gene_expr_22[i], i.e. the first five SAMPLE rows
    print(gene_expr_22[:, sorted_index[i]])
The index of the five topping ranking genes are: [17 41 16 81 90] The top 5 centralties are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ] The top 5 ranking gene expressions are: [ 9.79608829 0.59187087 0.59187087 0. 11.42057082 13.45375934 4.41184652 5.41233442 10.77161327 10.22566536 10.03868584 5.51190126 5.77501102 10.92286682 5.6050409 6.05361315 8.40630281 7.72084618 5.74803716 7.47570912 7.15991169 5.39504909 2.47622613 3.92603738 1.01027857 13.83498451 13.87784967 13.7711593 10.67108988 0. 0. 9.52826851 8.82942148 7.82453875 12.21663674 9.84065835 5.01986645 5.90279957 10.1099351 5.93902906 5.99745726 5.63479653 7.3927639 4.26735601 2.47622613 7.67861422 4.97334034 5.04224159 3.26629182 0.59187087 12.2261382 10.91227561 11.36228906 11.84883025 12.08169979 10.39063115 14.19514917 8.83903683 3.01795753 6.29636379 5.16967652 17.17356979 18.52516138 14.12612366 6.7207436 7.27775234 10.6862729 0. 0. 0. 0. 0. 4.69693415 8.83951875 11.44045394 0. 0. 5.58992824 9.99201778 0. 4.80112243 6.89684097 6.89684097 2.01539052 8.57886703 9.1671095 5.9743687 8.08651311 12.72775032 15.2057169 6.4381165 6.41257662 0. 6.81472985 13.61814457] [10.07046983 0. 0. 0. 13.08567162 14.53186268 10.46229777 9.8329264 13.52031174 13.96804574 13.79901848 8.260228 9.65216737 9.09621968 8.18865313 9.70963528 11.96387506 11.28256722 9.51009952 10.51952814 11.39463418 9.56670291 8.87603187 1.32716997 0.58784501 14.76843764 14.23302043 14.57201329 11.66463839 0. 0. 7.29642987 0. 0. 11.92025608 0. 9.82974017 10.07282933 0. 0.58784501 0. 9.46716738 8.88319924 8.42131588 8.22548703 11.04734207 8.51354601 7.9453736 6.35739365 4.41775127 8.92761798 5.92151726 14.18870474 13.40419691 15.60740834 14.78221504 14.77833898 0. 8.78683156 9.78190226 10.4486016 0. 0. 0.81114217 0. 10.29676711 13.20786846 0. 0. 0. 0. 0. 6.47383236 12.64973894 12.02801309 0.32365829 0.32365829 8.29183375 7.79663669 0.32365829 9.40429432 10.54647008 0. 7.42867841 8.12118122 9.1224346 0. 0. 
11.1972044 12.99393259 10.80074624 10.74981078 0. 11.44560981 0. ] [ 8.97091978 0. 0.45259543 0. 8.26311894 9.75490754 8.96454881 9.94811313 8.69377268 8.77611057 8.76759852 4.00972338 6.78232993 8.10219094 5.33731148 6.87130155 7.54758093 7.39801682 5.87905631 9.87243326 10.32239013 8.854005 8.46767483 3.75560333 2.53366301 11.63681491 10.90115351 11.42809318 8.61324801 0. 0. 6.29248177 8.80492477 8.40945034 6.16492301 10.48552746 9.22034447 5.94094193 9.80390508 0. 0. 6.08102852 6.13073529 7.75868294 6.58603511 6.61344178 4.23873337 4.70363231 2.6226727 2.53366301 7.05841391 4.50723013 8.70017612 7.95997876 10.36049597 11.0161397 13.34464267 9.93078759 8.68604874 8.82616393 8.85973666 14.8184224 16.05359715 10.80903655 6.8758292 6.37165643 9.71371643 0. 0. 0. 0. 0. 1.96484219 10.2450535 6.22301316 4.04235542 3.45276673 4.72389162 7.1435975 0.79659775 9.45885361 6.40967266 8.4974008 2.22801825 6.89829302 8.93189576 3.90715987 5.32410132 11.48706624 13.38059635 6.65623607 10.20973359 0. 7.74883018 12.75997554] [ 8.52461615 1.03941918 0.43488172 0. 10.7985204 12.26301973 7.44069479 8.06234301 8.80208333 9.23748724 9.35917193 5.80423946 5.10517519 8.00027048 4.07300613 5.73921895 7.84116804 7.28668712 5.74868932 7.61131995 8.13093602 7.41416091 6.17115878 4.13209148 2.80330971 13.46229777 13.20099973 13.24938608 10.39924588 0. 0. 6.56876026 10.07297667 8.76872689 9.41674054 10.07108689 7.47532812 5.9672424 9.28544846 0.76858664 0.43488172 5.72008963 6.86068972 5.5571803 5.32368351 6.63153214 3.89136086 4.98211286 2.80330971 2.47853181 10.60907714 7.78244133 10.21082961 9.91634527 12.92707792 9.6801151 12.92714461 10.28753925 6.56876026 7.13967413 8.12367654 17.37107895 18.37179366 13.77467383 5.52447433 6.60570915 9.71789915 0. 0. 0. 0. 0. 4.1430505 11.03974158 9.90527267 0. 
0.76858664 3.18447074 8.04898812 1.26735601 6.88442698 6.07656551 7.99232092 4.13209148 7.55305325 8.96062754 4.29608292 6.95974698 12.97463865 14.89181218 6.03072451 7.31564774 0.43488172 7.11792356 12.35327642] [ 8.04723845 0. 0. 0.36098224 12.2830102 14.03375851 8.71918002 8.83147193 8.46207277 8.21120206 8.23725777 8.67169674 4.71700735 9.85755963 2.85877675 4.34557379 8.02536661 7.11465853 5.5386499 9.12604104 8.10498132 8.41117623 6.82021726 4.99935515 5.27829333 17.30907662 16.93580962 16.64340528 13.74022347 0. 0. 8.21258879 0. 0. 13.60890163 0. 8.14464805 5.40951992 0. 5.16854176 4.26253839 5.64082045 8.24790378 6.0589009 5.04975273 7.23387848 5.54288019 5.5386499 3.95400084 1.09565442 8.99392339 6.10729451 10.37130837 9.45944391 12.18371911 10.86885341 13.04320675 0. 6.79828386 8.69167374 8.62882826 1.58009723 0. 5.93701005 2.96762975 6.57128545 10.0700941 0. 0. 0. 0. 0. 5.48910616 9.27257613 14.03528024 0.64938553 0. 4.1046302 8.88363303 1.94212035 7.28999119 6.21611142 1.71114245 7.1839926 6.0215086 7.3386116 0. 0. 11.33723721 13.39006145 5.98959318 8.35967051 0. 6.32754546 0. ]
# symmetric normalized Laplacian function
def compute_l_norm(A):
    """
    Return the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}.

    Args:
        A: (n, n) weighted adjacency matrix (assumed symmetric, zero diagonal).

    Returns:
        (n, n) ndarray. For isolated nodes (non-positive weighted degree) the
        convention D^{-1/2} = 0 is used, so such nodes contribute a diagonal 1
        and no off-diagonal terms instead of producing inf/NaN (the original
        divided by sqrt(degree) unguarded).
    """
    weighted_degree = A.sum(axis=1)
    # guard against division by zero (or sqrt of a negative weighted degree)
    inv_sqrt = np.zeros_like(weighted_degree, dtype=float)
    nonzero = weighted_degree > 0
    inv_sqrt[nonzero] = 1.0 / np.sqrt(weighted_degree[nonzero])
    D_inv_sqrt = np.diag(inv_sqrt)
    L_norm = np.eye(A.shape[0]) - D_inv_sqrt.dot(A.dot(D_inv_sqrt))
    return L_norm
# eigen decomposition of the symmetric normalized Laplacian
# (eigh is for symmetric matrices and returns eigenvalues in ascending order)
L_norm = compute_l_norm(A)
eigenvals, eigenvecs = np.linalg.eigh(L_norm)
# bare expression: displays the eigenvalue array in the notebook
eigenvals
array([-1.25666763e-15, -1.22009119e-15, -5.86259751e-16, -2.82855538e-16,
-2.76761826e-16, -2.26146964e-16, -1.95100938e-16, -1.49770206e-16,
-9.98842225e-17, 4.49123519e-18, 6.31547765e-17, 2.77555756e-16,
3.24301287e-16, 3.33281305e-16, 3.96520475e-16, 6.14434034e-16,
7.29563077e-16, 1.13662133e-15, 4.40674341e-01, 7.09824816e-01,
8.09205368e-01, 8.74415544e-01, 9.49448789e-01, 9.93084679e-01,
1.00281813e+00, 1.01076222e+00, 1.02782108e+00, 1.04465433e+00,
1.04783531e+00, 1.04956517e+00, 1.05073999e+00, 1.05146914e+00,
1.05162853e+00, 1.05387651e+00, 1.05532375e+00, 1.05570299e+00,
1.05633592e+00, 1.05665956e+00, 1.05968737e+00, 1.07419326e+00,
1.07444423e+00, 1.07577484e+00, 1.07772569e+00, 1.07840825e+00,
1.07930189e+00, 1.07966193e+00, 1.08036281e+00, 1.08130455e+00,
1.08252594e+00, 1.08512021e+00, 1.08600827e+00, 1.09526489e+00,
1.11444086e+00, 1.12087585e+00, 1.13008862e+00, 1.13374849e+00,
1.13484766e+00, 1.13769825e+00, 1.14151998e+00, 1.15353425e+00,
1.16016346e+00, 1.16477260e+00, 1.16726495e+00, 1.16883048e+00,
1.16939703e+00, 1.16957148e+00, 1.18065418e+00, 1.18908346e+00,
1.19607574e+00, 1.19941793e+00, 1.21379412e+00, 1.21536563e+00,
1.21748346e+00, 1.21859141e+00, 1.22229890e+00, 1.22473975e+00,
1.25620720e+00, 1.26225785e+00, 1.26294794e+00, 1.26405150e+00,
1.27433950e+00, 1.45496166e+00, 1.49789651e+00, 1.50210349e+00,
1.57934555e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00,
2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00,
2.00000000e+00, 2.00000000e+00, 2.00000000e+00])
# Plotting the spectrum of the normalized Laplacian; eigenvalues that are
# numerically zero (|lambda| < 1e-14) are counted in r and marked red
r = 0
fig, ax = plt.subplots(1)
plt.plot(eigenvals)
for i in range(len(eigenvals)):
    if abs(eigenvals[i]) < 1e-14:
        plt.scatter(i, eigenvals[i], color="red")
        r += 1
    else:
        plt.scatter(i, eigenvals[i], color="blue")
# re-plot one point of each colour purely to attach legend labels
plt.scatter(0, eigenvals[0], color="red", label="rounded zero eigenvalues")
# fixed off-by-one: the last point sits at index len-1, not len
plt.scatter(len(eigenvals) - 1, eigenvals[-1], color="blue", label="non-zero eigenvalues")
# vertical line marking the gap between zero and non-zero eigenvalues
plt.axvline(r-0.5, color="red", linestyle="--", label="gap=1e-14")
plt.xlabel("index")
plt.ylabel("Eigenvalues")
plt.title("Spectrum of eigenvalues.")
plt.legend()
plt.grid()
plt.show()
print("The number of zero eigenvalues is: r=", r)
The number of zero eigenvalues is: r= 18
Set the threshold as 1e-14, and round eigenvalues smaller than it down to 0. By doing so, 18 eigenvalues are rounded down to 0 and marked red in the plot.
By the lecture notes, the number of zero eigenvalues corresponds to the number of connected components of the graph. Therefore, $r=18$ tells us that the graph consists of 18 connected components. In reality, it is possible that there are edges connecting components, but quite few.
In this section, when finding the optimal k using the elbow method, a polynomial of order 8 is fitted to the within-cluster distance:
The loss curve is not smooth, thus determining k directly from the graph will lead to an inaccurate result.
The order of the polynomial is adjusted through trials to give an appropriate fit to the loss but without overfitting the data.
The optimal value of k is determined at the largest k where the gradient of the loss is less than -1:
# U: the eigenvectors of the r (numerically) zero eigenvalues
U = eigenvecs[:, :r]
# constructing T: the row-normalized version of U (spectral clustering embedding)
row_norms = np.linalg.norm(U, axis=1)
# guard: an all-zero row would cause division by zero; leave such rows as zero
row_norms[row_norms == 0] = 1.0
D = np.diag(1/row_norms)
T = D.dot(U)
# define elbow function
def norm_within_cluster_dis(X, cluster_labels):
    """
    Return the total within-cluster distance cost w_c.

    For each cluster the cost is 0.5 * sum_{j,k in cluster} ||j - k||^2 / n,
    which is algebraically identical to the sum of squared distances of the
    cluster's points to their centroid -- computed here vectorised in O(n)
    per cluster instead of the naive O(n^2) double loop.

    Args:
        X: (n_samples, n_features) data set.
        cluster_labels: array of cluster labels, one per row of X.

    Returns:
        The summed within-cluster cost over all clusters (0.0 for no points).
    """
    w_c = 0.0
    for label in np.unique(cluster_labels):
        # extract the rows belonging to this cluster
        cluster_elemts = X[cluster_labels == label, :]
        # identity: 0.5 * sum_{j,k} ||j-k||^2 / n == sum_j ||j - mean||^2
        centroid = cluster_elemts.mean(axis=0)
        w_c += np.sum((cluster_elemts - centroid) ** 2)
    return w_c
# Sweep over k: run k-means on the embedding T `times` times per k and
# record the best (lowest) within-cluster cost found for each k.
np.random.seed(42)
k_range = range(2, 40)
w_c = []
times = 100
for k in k_range:
    costs = []
    for _ in range(times):
        _, up_labels = k_clustering(T, k, max_iter=100, message=False)
        costs.append(norm_within_cluster_dis(T, up_labels))
    # keep only the best restart for this k
    w_c.append(np.min(costs))
# the within-cluster distance cost per k
print(w_c)
[64.61111111110448, 51.07142857143121, 42.851063829786696, 36.34999999999901, 33.05128205128229, 26.324110671936715, 23.780219780219834, 21.303030303030248, 17.76923076923074, 15.000000000000021, 13.111111111111118, 11.399999999999991, 8.857142857142863, 6.400000000000001, 4.4, 4.0, 3.9999999999999982, 3.1111111111111125, 1.9999999999999991, 2.0, 9.4374449716689e-29, 1.9999999999999996, 9.335365708757868e-29, 8.914833625882453e-29, 8.627403593411206e-29, 8.732139564492352e-29, 7.975464815141715e-29, 8.544480761022826e-29, 7.987259597634385e-29, 7.971117744695417e-29, 8.244491999274519e-29, 7.560213599202927e-29, 7.12001438346075e-29, 7.246370752697711e-29, 6.970868871646457e-29, 6.955285149371518e-29, 6.244727734560932e-29, 6.136287568266725e-29]
For elbow method, we want to tradeoff between k and the cost-- a k such that the cost is low while the k is small as well.
From lecture notes, the elbow curve is decreasing fast and then the decreasing rate slows down and finally remains almost stable. The elbow point is set at the 'middle' change-phase, i.e. the region where the curve's decreasing rate is slowing down.
Therefore, a tolerance eps should be set and adjusted through trials. (Intuitively from the plot, we should expect that it falls between 12-20.)
eps = 2  # cost tolerance: the first k whose cost drops to eps is the elbow
for idx, cost in enumerate(w_c):
    if cost <= eps:
        # k_range starts at 2, so index idx corresponds to k = idx + 2
        elbow_k = idx + 2
        print("The elbow k is: ", elbow_k)
        print("The cost is: ", cost)
        break
The elbow k is: 20 The cost is: 1.9999999999999991
# Elbow plot: within-cluster cost against k, with the elbow point in red.
plt.plot(k_range, w_c)
plt.scatter(k_range, w_c, label='Distance Values')
# elbow_k - 2 converts the cluster count back to a list index (k_range starts at 2)
plt.scatter(elbow_k, w_c[elbow_k-2], color="red", label="the elbow k")
plt.xlabel("number of clusters k")
plt.ylabel("within cluster distance")
plt.title("The within cluster distance against cluster number k")
plt.legend()
plt.show()
# Run k-means once more at the elbow k and group the point indices by label.
np.random.seed(4)
centroids, labels = k_clustering(T, elbow_k, max_iter=100, message=True)
# map each cluster label to the list of indices assigned to it
cluster_dict = defaultdict(list)
for lbl in range(elbow_k):
    cluster_dict[lbl].extend(list(np.where(labels == lbl)[0]))
print("The value of k at the elbow point is: ", elbow_k)
print("The clustering by elbow k: ", cluster_dict)
Iteration: 0
72.631579% labels changed
Iteration: 1
Labels unchanged! Terminating k-means.
The value of k at the elbow point is: 20
The clustering by elbow k: defaultdict(<class 'list'>, {0: [], 1: [84, 85], 2: [49, 73], 3: [29, 30, 50, 51], 4: [6, 7, 19, 20, 21, 22, 36, 43, 44, 55, 58, 59, 60, 77, 80, 91], 5: [8, 9, 10, 12, 14, 15, 16, 17, 18, 37, 41, 45, 46, 47, 48, 52, 53, 54, 65, 66, 81, 90, 93], 6: [1, 2, 3, 23, 24, 39, 40, 79, 92], 7: [0, 56, 75, 76], 8: [], 9: [25, 26, 27, 28, 34, 74], 10: [], 11: [42, 61, 62, 63, 64, 72, 86, 87], 12: [4, 5], 13: [], 14: [11, 13, 31, 78, 83], 15: [67, 68, 69, 70, 71], 16: [], 17: [32, 33, 35, 38, 57, 82, 94], 18: [88, 89], 19: []})
# define a function to catch the size of the clusters of T
def cluster_size(cluster_dict):
    """
    Return the sizes of each cluster of T.

    Args:
        cluster_dict: mapping cluster label -> list of member indices.

    Returns:
        defaultdict(int) mapping each label to the number of members.
    """
    size_dict = defaultdict(int)
    for label, members in cluster_dict.items():
        size_dict[label] = len(members)
    return size_dict
# Report how many nodes fall in each of the elbow-k clusters.
size_dict = cluster_size(cluster_dict)
print("The size of each cluster by elbow k is: ", size_dict)
The size of each cluster by elbow k is: defaultdict(<class 'int'>, {0: 0, 1: 2, 2: 2, 3: 4, 4: 16, 5: 23, 6: 9, 7: 4, 8: 0, 9: 6, 10: 0, 11: 8, 12: 2, 13: 0, 14: 5, 15: 5, 16: 0, 17: 7, 18: 2, 19: 0})
By setting a tolerance eps, the elbow k is found at k==20. Each column of T is an eigenvector corresponding to a unique subgraph of the original network. Therefore, clustering T should give an optimal k close to r=18. The k found by the elbow method makes sense from this perspective.
The construction of T is equivalent to projecting the nodes into a lower dimensional space (dim=18 in this case)
Similar to the clustering implemented in coursework 1, the clusters found have the smallest total within-cluster distance and greatest total inter-cluster distances. And these two criteria are substituted by connectivity: as many within-cluster edges and as few inter-cluster edges as possible. Implementing k-means clustering, we found that some of the labels have zero samples clustered in that cluster, and this is because of the update of labels with the change of centroids in each iteration of the k-means algorithm.
The clustering of T provides an insight into the modularity(how clusters connect) of the network. For T:
$\cdot$ each row represents the behavior of a node in these zero eigen-spaces. By the clustering results, there are 14 clusters of similar behavior of the nodes, where each cluster is densely connected within itself and sparsely connected with other clusters. This gives an insight into which of the features of the genes are more connected and which are not.
By 2.1.3, there are 18 clusters. But for the result of clustering of T, there are 14 non-empty clusters and 6 empty clusters. This means that some of the clusters must have been merged to form a bigger one in k-means clustering. Recall the process of updating the centroids of each cluster and reassigning labels -- bigger clusters tend to drag small clusters and merge them, as they are a main force dragging the centroids. In terms of the graph structure, some small clusters are more connected to bigger ones, e.g. the cluster labelled 1 is likely to be more attached to the cluster labelled 4.
# Extract the largest cluster and the spectrum of its normalised Laplacian.
largest_cluster = max(size_dict, key=size_dict.get)
largest_cluster_indices = cluster_dict[largest_cluster]
# boolean mask selecting the member nodes; induced adjacency matrix
mask = (labels == largest_cluster)
sub_A = A[mask, :][:, mask]
# indices of these nodes in the original network
mask_indx = np.where(mask)
# normalised Laplacian of the subgraph and its eigendecomposition
sub_L_norm = compute_l_norm(sub_A)
sub_eigenvals, sub_eigenvecs = np.linalg.eigh(sub_L_norm)
# Binary spectral partition: the sign of each entry of the first
# non-trivial eigenvector assigns its node to partition 0 or 1.
second_index = np.where(sub_eigenvals > 1e-7)[0][0]  # first non-trivial index
spectral_partition = sub_eigenvecs[:, second_index]
# threshold at zero (in place, on the eigenvector column view)
neg = spectral_partition < 0
pos = spectral_partition > 0
spectral_partition[neg] = 0
spectral_partition[pos] = 1
p1, p2 = np.where(spectral_partition == 0)[0], np.where(spectral_partition == 1)[0]
n1, n2 = len(p1), len(p2)
# Heatmaps of link magnitudes: within partition 1, within partition 2,
# and across the two partitions (wider third panel for the colour bar).
width_ratios = [1, 1, 1.8]
fig = plt.figure(figsize=(25, 15))
gs = gridspec.GridSpec(1, 3, width_ratios=width_ratios)
idx = np.array(largest_cluster_indices)
# (matrix, column indices, row indices, title) for each panel
panels = [
    (sub_A[p1, :][:, p1], p1, p1, 'Magnitude of Network Links within Partition 1'),
    (sub_A[p2, :][:, p2], p2, p2, 'Magnitude of Network Links within Partition 2'),
    (sub_A[p1, :][:, p2], p2, p1, 'Magnitude of the Network Links across Partitions'),
]
for slot, (mat, cols, rows, title) in enumerate(panels):
    ax = plt.subplot(gs[slot])
    im = ax.imshow(mat, cmap='viridis', vmin=-1, vmax=1)
    ax.set_xticks(np.arange(len(cols)))
    ax.set_yticks(np.arange(len(rows)))
    # label ticks with the node indices from the original network
    ax.set_xticklabels(idx[cols])
    ax.set_yticklabels(idx[rows])
    ax.set_title(title)
# colour bar attached to the last (cross-partition) panel
cbar2 = fig.colorbar(im, ax=ax, shrink=0.4)
cbar2.set_label('Magnitude of Lines')
# Show the plot
plt.show()
A corresponding network graph is plotted below.
within partitions:
across partitions:
# Draw the subgraph with networkx, colouring partition-1 nodes red, rest blue.
g = nx.Graph(sub_A)
colored_nodes = p1
members = set(colored_nodes)  # O(1) membership tests in the comprehension
node_colors = ['red' if node in members else 'blue' for node in g.nodes()]
nx.draw(g, node_size=30, node_color=node_colors)
# Degree centrality of the subgraph: weighted degree = row sums of sub_A.
sub_degree = sub_A.sum(axis=1)
# subgraph-local node positions sorted by decreasing degree
sub_sorted_index = np.argsort(sub_degree)[::-1]
# NOTE(review): sub_sorted_index holds subgraph-local positions; this assumes
# gene_expr_22 is already restricted/aligned to the subgraph's nodes. If it is
# indexed by the original network (as the mask_indx mapping below suggests),
# gene_expr_22[mask_indx[0][sub_sorted_index]] would be needed -- TODO confirm.
sub_sorted_gene_expr_22 = gene_expr_22[sub_sorted_index]
# mask_indx maps subgraph positions back to original-network node indices
print("The index of the five topping ranking genes are: ", mask_indx[0][sub_sorted_index[0:5]])
print("The top 5 centralities are: ", sub_degree[sub_sorted_index[0:5]])
print("The top 5 ranking gene expressions are:",)
for i in range(5):
    print(sub_sorted_gene_expr_22[i])
The index of the five topping ranking genes are: [17 41 16 81 90] The top 5 centralities are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ] The top 5 ranking gene expressions are: [ 8.92200751 1.65526023 0.44180215 0. 12.06016901 13.67437988 4.39674177 4.42107515 11.38857703 10.95483607 10.92828889 8.57963273 7.98435579 11.33975538 7.18736197 8.85325653 10.24577848 9.43419452 7.63761749 5.20278991 5.1010365 5.31091447 2.96487914 2.89710486 1.48104056 15.29502989 16.03244183 14.99020592 12.31127434 0. 0. 9.66501785 7.91964995 6.48955899 15.18833385 10.71144328 3.41602901 8.27041214 9.32156512 6.09506783 4.76317252 8.31978049 8.32314035 4.62336356 3.63985727 8.95049758 4.28954678 4.89062179 3.0296119 0.7795539 9.9098121 7.53265203 12.89064048 12.39811037 13.96305862 9.58934029 13.53431504 9.87500439 3.0296119 4.7820835 4.53694187 15.74874843 17.60456047 14.13791172 7.20498189 8.82161775 11.65521317 0. 0. 0. 0. 0. 5.19462292 10.23099312 15.05184618 0. 0. 3.55444151 10.49176293 1.9509909 4.51450104 8.16629868 8.70011022 6.15432599 5.75562469 7.39119203 3.31576909 5.94372914 11.72023571 13.80957723 7.48439596 4.26850925 0. 9.43867128 12.81238792] [ 8.11540786 0. 0. 0. 11.44178743 13.13296716 8.05858708 8.29733359 10.5051069 10.32072159 10.13971761 6.79013342 9.19840821 10.11150086 7.84939885 9.59669107 10.39006145 9.82915823 7.10945459 9.20889031 8.13752402 8.50382971 5.93018144 1.83608571 1.65168318 16.53078428 16.06013438 16.53424512 12.50169788 0. 0. 8.97098016 0. 0. 14.05376471 0. 7.47864021 8.73172632 0. 5.93018144 4.87157755 9.33482523 8.79097073 4.9335632 4.51403386 8.31102168 5.84672648 6.10198197 4.40156221 5.76951787 8.75050267 5.9994431 12.83481865 12.30149048 13.76738789 10.97626338 12.64077816 0. 6.61835026 8.49529143 8.08841725 2.61842662 3.12288803 4.61836788 0.51440004 8.40153661 11.55723871 0. 0. 0. 0. 0. 5.89009887 14.34901163 13.52010299 0. 0. 4.71566278 10.18607709 3.19214681 7.84671895 8.71711431 0. 
4.31105257 6.12872644 7.40423045 0. 0. 11.37563681 13.30660288 8.35132355 8.02101326 0. 10.14791719 0. ] [ 9.97363997 6.8802937 5.68310691 0. 11.61823407 13.9518584 5.29798006 4.20064559 11.16863451 10.96875318 10.40217003 8.33689033 6.85273548 11.94623374 5.42318148 7.08838545 9.4623264 8.47163457 6.70643385 6.81895279 5.91398082 5.85963169 3.70458428 7.30970371 7.01568266 15.17014023 15.5444936 15.36038417 12.44458529 0. 0. 10.70656509 6.76674091 7.51841711 12.76486745 10.61807344 5.29798006 5.85963169 8.34877683 9.32666767 10.64202461 7.65552357 8.95335837 4.79441067 3.17392693 8.57307151 6.70025937 7.15308584 5.75632681 2.76464339 9.99585193 8.44604527 12.11841878 11.52265509 12.32970666 7.80689757 14.71886461 9.48404546 4.56909002 6.63056419 6.16766301 0. 0. 0. 0. 8.49856203 11.78353032 0. 0. 0. 0. 0. 6.8886215 10.39599461 11.78316605 0. 0. 2.32553033 10.94708897 4.7710457 3.75310173 6.96510317 7.50785805 5.56579517 6.59107947 8.02012434 0. 0. 11.16275007 13.48954695 7.12571309 5.84851278 0. 7.16657486 12.34301356] [10.10335338 0. 0. 0. 12.02755703 13.95778223 10.62569058 9.96159873 8.42283848 9.34314573 9.2134931 6.2207962 5.24691347 10.15564046 2.7654711 5.1618232 8.9703619 9.0524322 7.54093394 8.64549993 8.77051524 8.8129923 8.77171064 0.78659636 0. 16.93219168 17.2118978 17.00718061 13.89325412 0. 0. 8.2290495 0. 0. 14.07698257 0. 10.85953479 6.89686518 0. 5.0399499 3.98748459 8.63895715 7.07066839 7.97843019 6.49458842 9.91143693 6.89168651 6.37366909 4.95819773 1.82268936 4.97492389 1.96343713 11.14972172 9.2405457 12.5257141 9.55806861 14.49403678 0. 9.02458841 8.80949805 9.46802118 0.44625623 0. 6.05765388 2.31825899 8.31202071 12.08188616 0. 0. 0. 0. 0. 3.53908435 10.11297374 15.3735483 0. 0. 0.78659636 8.43648669 1.06170708 8.72303971 8.4678419 0.78659636 3.77472382 4.20157113 5.43115157 0. 0.44625623 11.30287878 13.46658634 7.78807881 9.16694386 0. 8.52432623 0. 
] [ 8.69782594 3.51765364 3.86295716 8.78777199 12.33524239 13.99592611 8.04792549 9.37627519 12.82544579 13.02612398 12.83938293 8.57285557 7.80233541 9.14525457 6.80379532 8.36665006 11.64664421 10.41460068 8.88902628 8.77635407 8.38179816 7.94515702 7.02577078 5.50184674 4.4013301 14.44136209 14.23803894 14.70011997 12.05164739 0. 0. 7.64660191 8.79533513 6.26455144 12.59471415 10.78215462 7.58974851 9.58193304 10.32730562 2.1648828 2.1648828 8.85565651 10.49858996 6.1033025 5.4585559 10.04753294 8.47262992 8.23195075 6.58853625 3.41666412 5.75593174 3.06277796 12.74262944 12.09926232 14.42464527 11.2576648 13.47460589 9.04519721 6.51715441 7.88263082 8.74545248 1.11196609 0. 2.39467876 0.4720718 10.01820018 13.22198278 0. 0. 0. 0. 0. 7.88028758 11.61066961 12.96694158 1.11196609 0. 4.75161368 8.47107363 4.2627263 6.62294499 9.8383231 8.14385766 7.10093561 6.51104418 7.85819241 0. 0. 11.85066349 13.89602014 9.93506631 8.88075592 4.34747432 10.40391742 14.74470772]
# Compare the top-5 degree centralities of the full network (A, via
# sorted_index/degree computed earlier) with those of the largest-cluster
# subgraph (sub_A); both index lists are in original-network terms.
print("The index of the top 5 ranking genes are: ", sorted_index[0:5])
print("The centralities of top 5 gene expressions of the A are:", degree[sorted_index[0:5]])
print("The index of the top 5 ranking genes are: ", mask_indx[0][sub_sorted_index[0:5]])
print("The centralities of top 5 gene expressions of the sub_A are:", sub_degree[sub_sorted_index[0:5]])
The index of the top 5 ranking genes are: [17 41 16 81 90] The centralities of top 5 gene expressions of the A are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ] The index of the top 5 ranking genes are: [17 41 16 81 90] The centralities of top 5 gene expressions of the sub_A are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ]
| 17 | 41 | 16 | 81 | 90 | |
|---|---|---|---|---|---|
| centrality of original graph | 19.73839866 | 19.04945567 | 18.64028395 | 18.53179044 | 18.2670517 |
| centrality of the sub graph | 19.73839866 | 19.04945567 | 18.64028395 | 18.53179044 | 18.2670517 |
By the printed result above, the top 5 degree centralities of the subgraph remain identical to those of the original network.
This makes sense by the way we found the elbow k, and the biggest cluster:
The centrality of a network indicates the importance of each node: the higher the centrality, the more representative of a node. Therefore, the top 5 centralities of the subgraph is expected to be identical to be the global top 5 centralities of the original network.
Please delete this section if you are not a master student